minimaxir / aitextgen

A robust Python tool for text-based AI training and generation using GPT-2.

Home Page:https://docs.aitextgen.io

TPU Support

minimaxir opened this issue

Although you can train an aitextgen model on TPUs by setting n_tpu_cores=8 in an appropriate runtime, and the training loss indeed does decrease, there are a number of miscellaneous blocking problems:

  • The model stored in aitextgen does not update, even after training.
  • Saving the model via save_pretrained() causes a hang, even with xm.rendezvous()
  • Memory leaks on the host system (especially with large batch size)
  • fp16 doesn't work at all, and there's no training loss decrease.

Will gladly take any suggestions/PRs to help resolve these!

I've been off and on with PTL, mainly because I haven't found easy adapters for larger custom datasets that TPUs can handle, but there are two areas that I would look at.

In your code, you call save_pytorch_model as a function with pl_module.model.save_pretrained, which I think PTL uses primarily for end-of-training saving rather than in-the-loop checkpointing.

The model checkpointing callbacks (https://pytorch-lightning.readthedocs.io/en/latest/callbacks.html#model-checkpointing) seem to work more smoothly within PTL, especially with the distributed training functions (like TPUs and multi-GPU).

Another notebook I came across while trying to debug TPUs was this one which built a custom training loop and function purely through torch_xla rather than a wrapper.

A few things from the notebook would be:

As a global:

import multiprocessing
_LOAD_LOCK = multiprocessing.Lock()

within _mp_fn:

# assumes: import torch_xla.core.xla_model as xm
device = xm.xla_device()
with _LOAD_LOCK:
    _MODEL.to(device)
xm.master_print('done loading model')

within the training loop:

xm.save(_MODEL.state_dict(), args.checkpoint_path)

I might tackle it here soon, I just started playing with the library. Nice job on the tokenizer and dataloader. Those are two of the things that need improvement imo in the main transformers repo. I did notice that memory spiked up to around 8 GB while processing a 500 MB txt file line-by-line through the dataloader. I haven't seen any really good implementations for batching larger datasets (5 GB+) to fit into memory. Any ideas around that?
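
On the question of batching larger datasets: one common approach is to stream the file in fixed-size batches instead of materializing it all at once. A minimal sketch, assuming a plain-text file with one sample per line; `iter_line_batches` and its parameters are hypothetical names for illustration, not part of aitextgen or transformers:

```python
from itertools import islice
from typing import Iterator, List


def iter_line_batches(path: str, batch_size: int = 1000) -> Iterator[List[str]]:
    """Lazily yield lists of up to `batch_size` non-empty lines.

    Trailing newlines are removed; blank lines are skipped.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    with open(path, encoding="utf-8") as handle:
        # Generator expression: lines are read only as they are consumed.
        lines = (line.rstrip("\n") for line in handle if line.strip())
        while True:
            batch = list(islice(lines, batch_size))
            if not batch:
                return
            yield batch
```

Each batch can be tokenized and written out (or fed to training) before the next one is read, so peak memory for the raw text scales with batch_size rather than with the file size. The tokenized output still has to live somewhere, so this only addresses the text-reading side.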

The hang is possibly a pytorch-lightning issue: Lightning-AI/pytorch-lightning#2498

We now have CI on TPUs so we can fully detect bugs now.

Some things we’ve resolved:

  1. If you train with only a train loop and model checkpoint enabled, the training will hang. We’re currently looking into this.

  2. If you train on colab or kaggle you have to wait for the training to finish loading the model weights at the end. The reason is that the model is trained on a subprocess but not on the main one. This means that we need to save weights and rebuild the model in the main process. This looks like a “hang” at the end of training but it is in fact saving weights and loading them back on the main process.

  3. As mentioned earlier, we didn’t have CI before with TPUs, so we might have had unknown bugs. However, we’ve been adding tests this week to fortify the TPU code. Expect the next release in a few days to stabilize TPU training more thoroughly.
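
The save-and-reload step described in item 2 can be illustrated with a toy example. This is only a sketch of the general pattern (a child process writes weights to disk, and the parent loads them afterwards), using the standard library and a JSON file in place of real model weights. It is not Lightning's implementation, and `train_in_subprocess` is a hypothetical name:

```python
import json
import os
import subprocess
import sys
import tempfile

# Code run by the child process: "train" a toy weight and save it to disk.
_CHILD_CODE = """
import json, sys
weights = {"w": 0.0}
for _ in range(10):        # stand-in for training steps
    weights["w"] += 0.5
with open(sys.argv[1], "w") as fh:
    json.dump(weights, fh)
"""


def train_in_subprocess() -> dict:
    """Run toy training in a child process and return the reloaded weights."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "weights.json")
        # The parent's memory is not changed by anything the child does ...
        subprocess.run([sys.executable, "-c", _CHILD_CODE, path], check=True)
        # ... so the trained weights must be read back from disk explicitly.
        with open(path) as fh:
            return json.load(fh)
```

If the reload step is skipped or never completes, the parent process would still hold its pre-training weights, which is why the wait at the end of training matters.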

> we’ve been adding tests this week to fortify the TPU code. Expect the next release in a few days to stabilize TPU training more thoroughly.

Very impressive. Any recent progress? It would be very helpful if TPU support worked for aitextgen.

Hey, your work is really good. I was playing with your library and somehow managed to start training on a TPU. The loss decreases, but when I save the model, it saves and then hangs: it hangs when it has to continue training. I did the same thing without Lightning and it works well.
You can see the following notebook, which runs training with your library on a TPU:
https://colab.research.google.com/drive/1Rmc6I_Xlkq9yE0Yw67RJ88nGtghNAP51?usp=sharing

After further testing today, the model gets one training step on each of the TPU cores, then freezes (including when trying a few of @kbrajwani's approaches, which did not make a difference). The stack trace indicates that it's pytorch-lightning's fault, and the TPU CI there is apparently failing, so I will wait a bit until that is resolved.

@minimaxir mind sharing a notebook with this issue?

I was just trying to figure out what I was doing wrong today; it's good to know this is an existing issue. I had hoped to leverage Colab TPUs.

@jeremytech123 @minimaxir I am not able to get the notebooks mentioned above to run on colab. I get this error:
ModuleNotFoundError: No module named 'transformers.convert_gpt2_original_tf_checkpoint_to_pytorch'
Any chance you could share your notebook? I could look into the TPU issue.

I'm sorry, I've already changed it a dozen times since then and I don't have the errors saved for review. I think it may have something to do with the fact that I was trying to do training at the time.

Gave it a second try: it worked, but at about P100 speeds for 1 TPU core (which a T4 + FP16 beats), and oddly for 8 TPU cores it resulted in the same performance as 1 TPU core.

And there's a bit of weirdness in getting it from CPU to TPU.

I dunno how much of that is in pytorch-lightning and how much is in aitextgen and transformers, though. But if it's not using all 8 cores then it's not worth developing.

Rolling into #97