josiahdavis / repgpt

Reproducing GPT2

Reproduce GPT2 in half a day

  • Train GPT2-124M down to ~2.85 cross entropy loss on a single 8xA100 node in 12 hours.
  • ~500 lines of code in three source files.

How to Use

You will need access to a single 8xA100 node (40 GB memory per GPU) with CUDA drivers installed.

One time setup:

git clone https://github.com/josiahdavis/repgpt.git
cd repgpt
conda create -yn repgpt python=3.10
conda activate repgpt
pip install -r requirements.txt

Example training run on a single 8xA100 node:

torchrun --standalone --nproc_per_node=8 src/repgpt/train.py --max_steps 90000

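If you just want a quick smoke test first, the same entry point can be launched with fewer processes and fewer steps; this single-GPU variant is an assumption about the script, not a documented mode.

torchrun --standalone --nproc_per_node=1 src/repgpt/train.py --max_steps 100
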
Features

  • PyTorch 2.1, which supports model compilation and flash attention.
  • Automatic Mixed Precision (AMP) using BFloat16 (see the training-step sketch after this list).
  • Distributed Data Parallelism (DDP).
  • Gradient Accumulation.
  • Logging with TensorBoard and loguru.
  • Compatible with SageMaker training jobs and EC2.
  • Reference guides explaining key concepts.

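To make these concrete, here is a minimal, self-contained training-step sketch (not the repo's actual train.py) showing one DDP worker using bf16 autocast, gradient accumulation, and torch.compile; the tiny model, synthetic data, and hyperparameters are placeholders.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Tiny stand-in model so the sketch runs on its own; the real repo trains GPT2-124M.
class TinyLM(nn.Module):
    def __init__(self, vocab=50257, dim=128):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, x, y):
        logits = self.head(self.emb(x))
        loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        return logits, loss

dist.init_process_group(backend="nccl")          # launched via torchrun, which sets the env vars
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)

model = torch.compile(TinyLM().to(device))       # PyTorch 2.x model compilation
model = DDP(model, device_ids=[device.index])
opt = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)

accum_steps = 8                                  # gradient accumulation factor (illustrative)
opt.zero_grad(set_to_none=True)
for step in range(100):                          # stand-in for iterating a real dataloader
    x = torch.randint(0, 50257, (8, 64), device=device)
    y = torch.randint(0, 50257, (8, 64), device=device)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):   # AMP with BFloat16
        _, loss = model(x, y)
    (loss / accum_steps).backward()              # bf16: no GradScaler / loss scaling needed
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad(set_to_none=True)

dist.destroy_process_group()
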
Reference guides

  • 01_data.ipynb: Understand the data we are feeding into the model.
  • 02_attention.ipynb: Gain an understanding of the attention mechanism, and reproduce PyTorch's attention function with vanilla matrix multiplication (a short sketch of this follows the list).
  • 03_loss.ipynb: Explainer for cross entropy loss.
  • 04_transformer.ipynb: Standalone explainer for the full transformer architecture.
  • 05_training.ipynb: Notes on key training concepts and their implementation, such as AMP, DDP, and DDP with gradient accumulation.
  • 06_logging.ipynb: Explainer for logging.

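As an example of what the attention notebook covers, here is causal scaled dot-product attention written with plain matrix multiplications and checked against PyTorch's built-in function (the shapes are arbitrary):

import math
import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 8, 16                               # batch, heads, sequence length, head dim
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Manual version: scores, causal mask, softmax, weighted sum of the values.
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))
manual = scores.softmax(dim=-1) @ v

builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, builtin, atol=1e-6))      # True, up to floating-point noise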

FAQ

Q: Did you get it right the first time?

No. Here is a summary of my training log:

  1. First run: loss never got below 3. I realized I had a bug in the learning rate scheduler.
  2. Second run: loss never dropped below 3 again, but I realized I didn't use the paper's initialization.
  3. Third run: I got a much better result of 2.89 with the new correct initialization, but still not quite there.
  4. Fourth run: I discovered that the original authors didn't use dropout, so I turned it off, but then I started getting exploding gradients.
  5. Fifth run: I turned off automatic mixed precision (AMP), and I was able to get to the goal of ~2.85 validation loss.
  6. I did a bunch of debugging with mixed precision, trying to figure out what was causing the exploding gradients. As it turned out, when using the bfloat16 data format you don't need to perform loss scaling, and removing it fixed the issue (see the sketch after this list).
  7. After getting AMP to work, I was able to reproduce the same training result in roughly half the time.
  8. Tried some lightweight hyperparameter optimization, increasing the learning rate and weight decay, which got the training down to 24 hours.
  9. Did some additional experimentation and got the training down to 12 hours, which is where I stopped.

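To illustrate the loss-scaling point in step 6 with a generic PyTorch sketch (placeholder model, not the repo's code): float16 autocast is normally paired with a GradScaler because small gradients underflow in float16, whereas bfloat16 keeps float32's exponent range, so a plain backward pass is enough.

import torch

model = torch.nn.Linear(512, 512).cuda()          # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
x = torch.randn(8, 512, device="cuda")

# float16 AMP: loss scaling keeps small gradients from underflowing to zero.
scaler = torch.cuda.amp.GradScaler()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.zero_grad(set_to_none=True)

# bfloat16 AMP: same exponent range as float32, so no GradScaler is required.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad(set_to_none=True)
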
I live-tweeted my experience here.

Q: What was your learning & development process?

  1. Read through the foundational GPT1 and GPT2 papers by @AlecRad et al.
  2. Went through @karpathy's video lecture on GPT.
  3. Looked up a variety of online implementations: Harvard's annotated transformer, OpenAI, nanoGPT, Hugging Face, @benjamin_warner, etc.
  4. Reproduced the attention function from PyTorch with vanilla matrix multiplication.
  5. Created a script to poll for a p4 instance (8xA100) until one became available (~19 days 😓); a sketch of such a script follows this list.
  6. Implemented the training engineering (e.g., DDP, AMP, gradient accumulation, logging).
  7. Ran training multiple times, debugging issues including: learning rate scheduler bug, initialization issue, removing dropout, fixing gradient explosion.
  8. After getting it working, sped up training from 6.5 days down to 12 hours.

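The polling script from step 5 might look roughly like this; it is a sketch assuming on-demand capacity was the bottleneck, and the region, AMI ID, and key pair name are placeholders.

import time
import boto3
from botocore.exceptions import ClientError

ec2 = boto3.client("ec2", region_name="us-east-1")    # placeholder region

while True:
    try:
        ec2.run_instances(
            ImageId="ami-0123456789abcdef0",          # placeholder Deep Learning AMI
            InstanceType="p4d.24xlarge",              # 8xA100 (40GB)
            MinCount=1,
            MaxCount=1,
            KeyName="my-key-pair",                    # placeholder key pair
        )
        print("Launched a p4d.24xlarge")
        break
    except ClientError as err:
        if err.response["Error"]["Code"] == "InsufficientInstanceCapacity":
            time.sleep(300)                           # no capacity yet, retry in 5 minutes
        else:
            raise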

License

MIT License

