bpucla / vit_10b_fsdp

See details in https://github.com/pytorch/xla/blob/r1.12/torch_xla/distributed/fsdp/README.md

Vision Transformer (ViT) model using PyTorch/XLA FSDP

This repo implements sharded training of a Vision Transformer (ViT) model at the 10-billion-parameter scale using the FSDP (Fully Sharded Data Parallel) algorithm in PyTorch/XLA. The XLA FSDP interface is officially supported as of the PyTorch/XLA 1.12 release.
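In FSDP, each of the N devices keeps only about 1/N of every (flattened) parameter and all-gathers the full tensor when a layer needs it. The plain-Python sketch below illustrates only that bookkeeping under simplifying assumptions (one flat list of parameters, zero padding, a simulated all-gather). The function names are hypothetical, and this is not how torch_xla implements FSDP internally.

```python
from typing import List


def shard_flat_params(flat: List[float], world_size: int, rank: int) -> List[float]:
    """Return the contiguous slice of a flattened parameter owned by `rank`.

    The flat vector is zero-padded so its length is divisible by world_size,
    which lets every device hold an equally sized shard.
    """
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return padded[rank * shard_len:(rank + 1) * shard_len]


def all_gather(shards: List[List[float]], orig_len: int) -> List[float]:
    """Simulate an all-gather: concatenate every rank's shard and drop the padding."""
    full = [x for shard in shards for x in shard]
    return full[:orig_len]


if __name__ == "__main__":
    params = [float(i) for i in range(10)]  # a toy "model" with 10 parameters
    world = 4
    shards = [shard_flat_params(params, world, r) for r in range(world)]
    print(shards)  # each device stores 3 values (10 padded up to 12)
    assert all_gather(shards, len(params)) == params
```

Only the shard is stored persistently; the full parameter exists just for the duration of the all-gather, which is what lets the per-device memory shrink as more devices are added.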


Installation

  1. Allocate a TPU VM pod from the tpu-vm-pt-1.12 runtime following the TPU VM instructions, for example a v3-128 pod (e.g. named rh-128-0 in zone europe-west4-a) or the v3-32 pod used in the commands below. You can also try out larger TPU pods such as v3-256 or v3-512.
TPU_NAME=sfr-b-pang-tpu-32-us-east1-1  # change to your TPU name
ZONE=us-east1-d  # change to your TPU zone
ACCELERATOR_TYPE=v3-32  # you can also try out larger TPU pods
RUNTIME_VERSION=tpu-vm-pt-1.12  # the XLA FSDP interface is supported in PyTorch/XLA 1.12

gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --version ${RUNTIME_VERSION}
  2. Install timm as a dependency (to create vision transformer layers) and clone this repository to all TPU VM nodes as follows.
TPU_NAME=sfr-b-pang-tpu-32-us-east1-1  # change to your TPU name
ZONE=us-east1-d  # change to your TPU zone

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} \
  --worker all \
  --command "
# ViT dependency
sudo pip3 install timm==0.4.12

# clone this repo (ViT FSDP example) into ~/vit_10b_fsdp_example
cd ~ && rm -rf vit_10b_fsdp_example && git clone https://github.com/bpucla/vit_10b_fsdp vit_10b_fsdp_example
"
  3. Download ImageNet-1k to a shared directory (e.g. /datasets/imagenet-1k) that all nodes can access. It should have the following structure, with the validation images moved into labeled subfolders as in the PyTorch ImageNet example.
/datasets/imagenet-1k
|_ train
|  |_ <n0......>
|  |  |_ <im-1-name>.JPEG
|  |  |_ ...
|  |  |_ <im-N-name>.JPEG
|  |_ <n1......>
|  |  |_ <im-1-name>.JPEG
|  |  |_ ...
|  |  |_ <im-M-name>.JPEG
|  |_ ...
|_ val
   |_ <n0......>
   |  |_ <im-1-name>.JPEG
   |  |_ ...
   |  |_ <im-N-name>.JPEG
   |_ <n1......>
   |  |_ <im-1-name>.JPEG
   |  |_ ...
   |  |_ <im-M-name>.JPEG
   |_ ...
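A layout with one subfolder per class under both train and val is what torchvision's ImageFolder expects; whether run_vit_training.py reads the data exactly that way is an assumption here. Before launching, a quick stdlib-only check like the sketch below can catch the common mistake of leaving the validation images unsorted. The helper names are hypothetical and not part of this repo.

```python
import os
import sys
from typing import List


def class_dirs(split_dir: str) -> List[str]:
    """Sorted names of the class subfolders (e.g. n01440764) in one split."""
    return sorted(
        name for name in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, name))
    )


def check_imagenet_layout(root: str) -> List[str]:
    """Check that root/train and root/val use the same class subfolders.

    Returns a list of human-readable problems; an empty list means the
    layout looks right.
    """
    problems = []
    splits = {}
    for split in ("train", "val"):
        path = os.path.join(root, split)
        if not os.path.isdir(path):
            problems.append(f"missing directory: {path}")
            continue
        splits[split] = class_dirs(path)
        if not splits[split]:
            problems.append(f"no class subfolders in {path}")
        # images sitting directly in the split folder were not sorted by class
        loose = [n for n in os.listdir(path) if n.lower().endswith(".jpeg")]
        if loose:
            problems.append(f"{len(loose)} images directly under {path}; move them into class subfolders")
    if len(splits) == 2 and splits["train"] != splits["val"]:
        problems.append("train and val have different class subfolders")
    return problems


if __name__ == "__main__":
    root = sys.argv[1] if len(sys.argv) > 1 else "/datasets/imagenet-1k"
    print("\n".join(check_imagenet_layout(root)) or "layout looks OK")
```

Running it once from any node (e.g. `python3 check_layout.py /datasets/imagenet-1k`, with a hypothetical file name) also confirms that the shared directory is actually visible from that node.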

You can use a Persistent Disk or a Filestore NFS share on GCP to store the ImageNet-1k dataset.

Alternatively, you can pass --fake_data to test the model on a fake dataset (dummy images filled with all zeros).
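The installation commands above run on every host of the pod (--worker all), and a v3 pod name encodes its TensorCore count rather than its host count. The sketch below makes that mapping explicit, assuming the Cloud TPU v3 layout of 4 chips (8 TensorCores) per TPU VM host; the function is a hypothetical helper for illustration only.

```python
def v3_pod_layout(accelerator_type: str) -> dict:
    """Split a Cloud TPU v3 accelerator type such as "v3-32" into cores/chips/hosts.

    Assumes the v3 layout: 2 TensorCores per chip and 8 TensorCores per host.
    """
    generation, _, cores_str = accelerator_type.partition("-")
    if generation != "v3" or not cores_str.isdigit():
        raise ValueError(f"expected a v3-<cores> type, got {accelerator_type!r}")
    cores = int(cores_str)
    cores_per_host = 8
    if cores % cores_per_host:
        raise ValueError("v3 core counts are multiples of 8")
    return {"cores": cores, "chips": cores // 2, "hosts": cores // cores_per_host}


if __name__ == "__main__":
    for acc in ("v3-32", "v3-128", "v3-512"):
        print(acc, v3_pod_layout(acc))
```

So a v3-32 pod has 4 hosts and a v3-128 pod has 16, each of which needs its own copy of timm and of this repo, while the dataset directory only needs to be shared.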

Running the experiments

  1. Log in to worker 0 of your TPU VM pod.
TPU_NAME=sfr-b-pang-tpu-32-us-east1-1  # change to your TPU name
ZONE=us-east1-d  # change to your TPU zone

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --worker 0
  2. Before running any experiments, set up the gcloud ssh configuration on your TPU VM as follows (you only need to do this once):
cd ${HOME} && gcloud compute config-ssh --quiet
  3. Now we can run the experiments. For example, to train a ViT model with 10 billion parameters (5120 embed dim, 32 attention heads, 32 layers, and an MLP ratio of 4.0, which gives a feed-forward MLP dim of 20480 = 5120 * 4.0), launch the following in a tmux session.
TPU_NAME=sfr-b-pang-tpu-32-us-east1-1  # change to your TPU name
SAVE_DIR=~/vit_10b_fsdp_example_ckpts  # this can be any directory (it doesn't need to be a shared one across nodes)

mkdir -p ${SAVE_DIR}
cd ${HOME} && python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} --restart-tpuvm-pod-server --env PYTHONUNBUFFERED=1 -- \
python3 -u ~/vit_10b_fsdp_example/run_vit_training.py \
  --data_dir /datasets/imagenet-1k \
  --ckpt_dir ${SAVE_DIR} \
  --image_size 224 \
  --patch_size 14 \
  --embed_dim 5120 \
  --mlp_ratio 4.0 \
  --num_heads 32 \
  --num_blocks 32 \
  --batch_size 1024 \
  --num_epochs 300 \
  --lr 1e-3 \
  --weight_decay 0.1 \
  --clip_grad_norm 1.0 \
  --warmup_steps 10000 \
  --log_step_interval 20 \
  2>&1 | tee ${SAVE_DIR}/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log
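As a sanity check on the "10 billion" figure, the sketch below counts parameters for a standard timm-style ViT with the flags above: fused qkv projection with bias, learned position embedding, a class token, and an assumed 1000-class head. The model built by run_vit_training.py may differ in these details, so treat the result as an estimate; vit_param_count is a hypothetical helper.

```python
def vit_param_count(
    embed_dim: int = 5120,
    num_blocks: int = 32,
    mlp_ratio: float = 4.0,
    image_size: int = 224,
    patch_size: int = 14,
    in_chans: int = 3,
    num_classes: int = 1000,  # assumption: ImageNet-1k classification head
) -> int:
    """Parameter count of a standard (timm-style) ViT with biases everywhere."""
    d = embed_dim
    hidden = int(d * mlp_ratio)  # 20480 for the 10B config
    num_patches = (image_size // patch_size) ** 2  # (224 // 14) ** 2 = 256

    patch_embed = in_chans * patch_size * patch_size * d + d  # conv weight + bias
    cls_token = d
    pos_embed = (num_patches + 1) * d  # +1 for the class token

    per_block = (
        2 * d                    # LayerNorm before attention (weight + bias)
        + 3 * d * d + 3 * d      # fused qkv projection
        + d * d + d              # attention output projection
        + 2 * d                  # LayerNorm before the MLP
        + d * hidden + hidden    # MLP fc1
        + hidden * d + d         # MLP fc2
    )
    final_norm = 2 * d
    head = d * num_classes + num_classes
    return patch_embed + cls_token + pos_embed + num_blocks * per_block + final_norm + head


if __name__ == "__main__":
    n = vit_param_count()
    print(f"{n:,} parameters (~{n / 1e9:.2f}B)")
```

With the defaults this gives about 10.08B parameters, almost all of it in the 32 blocks (roughly 12 * 5120^2 each). With embed_dim=768, num_blocks=12 and patch_size=16 the same formula reproduces the 86.6M parameters of ViT-B/16. num_heads does not appear because the heads partition the same qkv projection rather than adding weights.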

Note that these hyperparameters (e.g. the learning rate) are not necessarily optimal, and you may need to tune them to get the best performance. For comparison, you can pass --run_without_fsdp to launch without FSDP, which can only fit much smaller models.
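A back-of-the-envelope estimate shows why running without FSDP only fits much smaller models: without sharding, every core must hold the full parameters, gradients, and optimizer state. The sketch assumes fp32 training with an Adam-style optimizer (two moment buffers, 16 bytes per parameter in total), ignores activations and temporary buffers, and uses 16 GiB of HBM per TPU v3 core. These are simplifications, not measurements of this codebase.

```python
GIB = 1024 ** 3


def per_device_state_gib(num_params: int, num_devices: int, sharded: bool,
                         bytes_per_param: int = 16) -> float:
    """Rough per-device memory (GiB) for parameters, gradients and optimizer state.

    bytes_per_param=16 assumes fp32 weights (4) + fp32 grads (4) + two fp32
    Adam moments (8). Activations and temporary buffers are ignored.
    """
    total = num_params * bytes_per_param
    per_device = total / num_devices if sharded else total
    return per_device / GIB


if __name__ == "__main__":
    n = 10_000_000_000     # roughly the 10B-parameter ViT above
    hbm_per_core_gib = 16  # TPU v3: 16 GiB of HBM per TensorCore
    for sharded in (True, False):
        need = per_device_state_gib(n, 32, sharded)
        print(f"sharded={sharded}: {need:.1f} GiB/core "
              f"(fits in {hbm_per_core_gib} GiB: {need < hbm_per_core_gib})")
```

On 32 cores the sharded state is under 5 GiB per core, leaving room for activations, while the unsharded state (about 149 GiB) exceeds a single core's HBM several times over.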

You can also try models larger than the 10-billion-parameter one above. In general, you will need more TPU cores to fit more parameters. When running this codebase on even larger models (e.g. 60B parameters), don't worry if you see messages like "tcmalloc: large alloc 1677729792 bytes == 0x181ff4000"; they are not errors. You can suppress them by passing --env TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=4294967296 to torch_xla.distributed.xla_dist, which raises the tcmalloc report threshold to 4294967296 bytes (4 GB).
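If a run was already launched without the raised threshold, the saved stdout_stderr_*.log file can instead be cleaned up afterwards. The sketch below is a different technique from the environment variable (post-processing the log rather than changing tcmalloc's behavior): it drops "large alloc" lines below a chosen size. The regex and helper name are assumptions for illustration.

```python
import re
from typing import Iterable, Iterator

# Matches lines such as "tcmalloc: large alloc 1677729792 bytes == 0x181ff4000"
TCMALLOC_RE = re.compile(r"tcmalloc: large alloc (\d+) bytes")


def drop_tcmalloc_noise(lines: Iterable[str], threshold: int = 4294967296) -> Iterator[str]:
    """Yield log lines, skipping tcmalloc large-alloc reports smaller than `threshold` bytes.

    The default of 4294967296 bytes (4 GiB) matches the threshold suggested above.
    """
    for line in lines:
        match = TCMALLOC_RE.search(line)
        if match and int(match.group(1)) < threshold:
            continue  # allocation below the threshold: treat as noise
        yield line


if __name__ == "__main__":
    import sys
    # usage: python3 filter_log.py < stdout_stderr_<timestamp>.log
    sys.stdout.writelines(drop_tcmalloc_noise(sys.stdin))
```

The example message above reports about 1.56 GiB (1677729792 bytes), so it falls below the 4 GiB threshold and is filtered out.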
