ReAct architecture and training loop - now in JAX!
This is the runner script for the Docker container. It pulls the latest version of the code from the dev branch and runs train_model.py with the arguments specified in TRAIN_ARGS.
Thus you can easily modify the arguments in the code block below and save the updated file somewhere (the TPU command further down expects it at ./run.sh). Every time you run it, it pulls the latest git version of the branch set in BRANCH.
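A minimal sketch of how a TRAIN_ARGS string like the one in the script reaches train_model.py as separate arguments, using a shortened argument list for illustration. Inside double quotes a trailing backslash-newline is a line continuation, so the value is a single line, and the unquoted $TRAIN_ARGS in the docker exec line is split on whitespace into individual argv words:

```shell
# Shortened stand-in for TRAIN_ARGS: the backslash-newline inside double
# quotes is removed, so this is a single line.
TRAIN_ARGS="--epochs 4 --lr 3.5e-3 \
--bf16 --wandb"

# Unquoted on purpose: word splitting turns the string into separate
# positional parameters, the same thing that happens in
# `python3 train_model.py $TRAIN_ARGS`.
set -- $TRAIN_ARGS
argc=$#
echo "TRAIN_ARGS=$TRAIN_ARGS"
echo "argc=$argc"   # prints argc=6
printf '  %s\n' "$@"
```

Because of this, every continued line needs a space before its trailing backslash; without it, the last word of one line and the first word of the next are glued into a single argument (for example 3.5e-3--bf16). Argument values that themselves contain spaces cannot be passed through this variable.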
Run the script below with elevated permissions (sudo)!
#!/bin/bash
BRANCH="dev"
IMAGE_NAME="docker.io/neel04/react_image:latest"
CONTAINER_NAME="react_container"
# arguments for train_model.py
TRAIN_ARGS="--save_dir ./ReAct/outputs/ --epochs 4 --warmup_steps 250 \
--lr 3.5e-3 --num_blocks 4 \
--width 128 --batch_size 512 --n_heads 4 --max_iters 5 \
--weight_decay 1e-4 --drop_rate 0.02 \
--log_interval 1000 --save_interval 1000 --seqlen 192 \
--bf16 --wandb"
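git clone -b $BRANCH https://github.com/neel04/ReAct_Jax.git
git config --global safe.directory '*'
# The clone lands in ./ReAct_Jax, so pull inside it (the clone fails harmlessly if it already exists)
git -C ReAct_Jax pull --all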
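# Stop all running Docker containers (skipped when none are running)
echo "Stopping all running Docker containers..."
RUNNING_CONTAINERS=$(sudo docker ps -q)
if [ -n "$RUNNING_CONTAINERS" ]; then sudo docker stop $RUNNING_CONTAINERS; fi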
sudo -s <<EOF
# Git stuff
git config --global safe.directory '*'
# Run the Docker container
echo "Running Docker container..."
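# Env vars for JAX/Equinox are set on the container (-e), where train_model.py actually runs
docker run --pull 'always' -v $(pwd)/ReAct_Jax/:/ReAct_Jax/ -e EQX_ON_ERROR=nan -e JAX_TRACEBACK_FILTERING=off --privileged --rm --net=host --name $CONTAINER_NAME -it -d $IMAGE_NAME
# Get docker container ID to copy files. The escaped dollar signs make the root shell
# evaluate these two lines after the container exists, instead of the outer shell beforehand
CONTAINER_ID=\$(docker ps -aqf "name=$CONTAINER_NAME")
docker cp $(pwd)/ReAct_Jax \$CONTAINER_ID:/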
# Execute train_model.py inside the Docker container
echo "Executing train_model.py inside Docker container..."
docker exec --privileged $CONTAINER_NAME git config --global safe.directory '*'
docker exec --privileged $CONTAINER_NAME python3 train_model.py $TRAIN_ARGS
EOF
echo "Finished training!"
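The sudo -s <<EOF block in the script uses an unquoted heredoc, so the outer shell performs every variable and command substitution in it before the root shell runs a single line. That is what lets CONTAINER_NAME and TRAIN_ARGS through, but it also means a variable assigned inside the block (such as CONTAINER_ID) has to be written with an escaped dollar sign wherever it is set or used. A small sketch of the difference, with plain sh standing in for sudo -s so it runs without root, and hypothetical NAME and ID variables standing in for the script's own:

```shell
NAME="react_container"
unset ID  # make sure the outer shell has no ID of its own

# Unquoted heredoc: the OUTER shell expands $(...) and $ID while building the
# text. ID is only assigned inside the inner shell, so $ID becomes "".
broken_heredoc() {
sh <<EOF
ID=$(echo "id-for-$NAME")
echo "[$ID]"
EOF
}

# Escaped dollar signs: the inner shell performs these expansions itself and
# sees the ID it just assigned; $NAME is still filled in by the outer shell.
fixed_heredoc() {
sh <<EOF
ID=\$(echo "id-for-$NAME")
echo "[\$ID]"
EOF
}

broken=$(broken_heredoc)
fixed=$(fixed_heredoc)
echo "broken: $broken"   # broken: []
echo "fixed:  $fixed"    # fixed:  [id-for-react_container]
```

Quoting the delimiter (<<'EOF') would defer every expansion instead, which would also stop CONTAINER_NAME and TRAIN_ARGS from being filled in, so selective escaping is the better fit here.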
To run inference with a saved checkpoint:
python3 inferencer.py --checkpoint_path '/Users/neel/Documents/research/ReAct_Jax/ReAct/outputs/model 5000.eqx' --num_blocks 3 --width 256 --n_heads 4 --seqlen 196 --prompt "Sam is sad because"
Getting a preemptible TPUv4-8 node
gcloud alpha compute tpus queued-resources create node-v4 \
--node-id node-v4 \
--project react-jax \
--zone us-central2-b \
--accelerator-type v4-8 \
--runtime-version tpu-vm-v4-base \
--metadata-from-file startup-script=./run.sh \
--best-effort
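A --best-effort request is queued rather than granted immediately, so the node may take a while to appear. A hypothetical helper that polls the queued resource until it becomes ACTIVE could look like the sketch below. It assumes that the describe subcommand exposes the state at state.state and that ACTIVE, FAILED and SUSPENDED are among the state names; check both against the gcloud documentation for your version:

```shell
# Hypothetical helper: poll a TPU queued resource until it is ACTIVE.
# Assumes `describe` exposes the state at state.state (check your gcloud version).
wait_for_tpu() {
  local name="$1" zone="$2" project="$3" interval="${4:-60}"
  local state
  while :; do
    # Stop polling if gcloud itself fails (for example auth errors or a wrong name).
    state=$(gcloud alpha compute tpus queued-resources describe "$name" \
      --zone "$zone" --project "$project" \
      --format='value(state.state)') || return 1
    echo "queued resource $name: ${state:-unknown}"
    case "$state" in
      ACTIVE) return 0 ;;              # node is provisioned and usable
      FAILED|SUSPENDED) return 1 ;;    # request will not become usable as-is
    esac
    sleep "$interval"
  done
}

# Usage, with the same names as the create command above:
# wait_for_tpu node-v4 us-central2-b react-jax && \
#   gcloud compute tpus tpu-vm ssh node-v4 --zone us-central2-b --project react-jax
```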