
jax-nanoGPT

A reimplementation of nanoGPT in JAX.

Install

Install dependencies:

pip install -r requirements.txt

If you want to use this code with TPUs, install:

pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
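
After installing, you can verify that JAX sees the TPU cores:

python3 -c "import jax; print(jax.devices())"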

Train single-node

To create a dataset, run:

cd data/shakespeare
python prepare.py

This will create train.bin and val.bin, which hold the GPT-2 BPE token ids as one long sequence. Now you can train: go back to the folder with the training script and run:

python train.py --config shakespeare
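
For reference, here is a minimal sketch of how the token bins can be read back and batched. It assumes prepare.py writes plain uint16 GPT-2 token ids (as nanoGPT's Shakespeare script does); the repo's actual data loader may differ:

    import numpy as np

    # train.bin / val.bin are flat arrays of uint16 token ids, so
    # memory-map them instead of loading everything into RAM.
    data = np.memmap("data/shakespeare/train.bin", dtype=np.uint16, mode="r")

    block_size, batch_size = 256, 8
    # Sample random windows; targets are the inputs shifted right by one token.
    ix = np.random.randint(0, len(data) - block_size - 1, size=batch_size)
    x = np.stack([data[i : i + block_size].astype(np.int64) for i in ix])
    y = np.stack([data[i + 1 : i + 1 + block_size].astype(np.int64) for i in ix])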

Train multi-node on GCP

We can scale training by using TPU pod slices and TPU VMs. In short, we deploy multiple workers, run the same training job on each one, and let pmap handle the scaling.
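
As background, the pmap pattern for data parallelism looks roughly like the sketch below. The parameters and loss function are hypothetical stand-ins, not the repo's actual train step:

    from functools import partial

    import jax
    import jax.numpy as jnp

    def loss_fn(params, x, y):
        # Hypothetical squared-error loss; the real model lives in train.py.
        pred = x @ params["w"]
        return jnp.mean((pred - y) ** 2)

    @partial(jax.pmap, axis_name="batch")
    def train_step(params, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        # All-reduce gradients across every device (and every host on a pod slice).
        grads = jax.lax.pmean(grads, axis_name="batch")
        params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
        return params, loss

    n = jax.local_device_count()
    # Replicate parameters and shard the batch along a leading device axis.
    params = {"w": jnp.stack([jnp.eye(4)] * n)}
    x, y = jnp.ones((n, 8, 4)), jnp.ones((n, 8, 4))
    params, loss = train_step(params, x, y)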

  1. We'll be using TPU v4, which requires a subnet in the zone us-central2-b. Follow the instructions for Set up and prepare a Google Cloud project.

  2. Create an instance, replacing <your_project_id> with your project id.

    export TPU_NAME=tpu-v4
    export ZONE=us-central2-b
    export RUNTIME_VERSION=tpu-vm-v4-base
    export PROJECT_ID=<your_project_id>
    export ACCELERATOR_TYPE=v4-16
    
    gcloud compute tpus tpu-vm create ${TPU_NAME} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --version ${RUNTIME_VERSION} \
    --subnetwork=tpusubnet \
    --network=tpu-network
  3. To ssh into the machine, you might need to modify ~/.ssh/config. Replace <your_user_name> with your local user name (check with echo ~), set HostName to your TPU VM's external IP, and add the following:

    Host tpu-v4
    HostName 107.167.173.130
    IdentityFile /Users/<your_user_name>/.ssh/google_compute_engine
  4. As a test, try to ssh in. If this works, you're ready for the next steps.

    gcloud compute tpus tpu-vm ssh tpu-v4 --worker=0 --zone us-central2-b --project $PROJECT_ID
  5. Now we'll run a training job on multiple machines. First, install jax[tpu], clone the repository on all machines, and install the dependencies:

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="git clone https://github.com/entrpn/jax-nanoGPT.git"

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="pip install -r jax-nanoGPT/requirements.txt"
  6. Generate the dataset on all workers. (TODO: generate the data on a single drive and mount it on all instances.)

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="python3 jax-nanoGPT/data/openwebtext-10k/prepare.py"
  7. Kick off training; a quick multi-host sanity check is shown after this list.

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="cd jax-nanoGPT; python3 train.py --config openwebtext-10k"
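
Once the job is running, you can sanity-check that every host sees the full pod slice; the global device count should equal the local device count times the number of processes:

    gcloud compute tpus tpu-vm ssh tpu-v4 --zone us-central2-b --project $PROJECT_ID --worker=all --command="python3 -c 'import jax; print(jax.process_index(), jax.local_device_count(), jax.device_count())'"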

Generate

To generate text, use the generate.py script with the config that was used for training and the last checkpoint step that was saved.

python generate.py --config shakespeare --checkpoint-step 7500
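
Under the hood, generation is a plain autoregressive sampling loop. A minimal sketch of the idea, where model_apply, params, and block_size are hypothetical stand-ins for the repo's model and checkpoint:

    import jax
    import jax.numpy as jnp

    def sample(model_apply, params, prompt_ids, steps, temperature=0.8, block_size=256):
        key = jax.random.PRNGKey(0)
        tokens = list(prompt_ids)
        for _ in range(steps):
            # Forward pass over the most recent block_size tokens;
            # keep only the logits for the last position.
            ctx = jnp.array(tokens[-block_size:])[None, :]
            logits = model_apply(params, ctx)[0, -1]
            key, subkey = jax.random.split(key)
            # Temperature-scaled categorical sample over the vocabulary.
            next_id = jax.random.categorical(subkey, logits / temperature)
            tokens.append(int(next_id))
        return tokens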

TensorBoard logs will be stored in out-{dataset-name} and include train/eval loss, learning rate, and sampled generations.
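
To view them, point TensorBoard at the output directory, e.g. for the shakespeare config:

tensorboard --logdir out-shakespeare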

Examples

Training with the openwebtext-10k dataset for 25k steps, where the last 50 characters of the text are generated by the model.



About

License: Apache License 2.0

