kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

Resolving dependency issues

rinapch opened this issue · comments

There have been a number of issues regarding different version conflicts and how to fix them. I've spent some time getting this code to run, so maybe these instructions will spare someone else the effort :)

First of all, as per this issue in the jax repo (google/jax#13321), TPU VMs no longer work with jax older than 0.2.16. This repo requires jax==0.2.12, but I found that the code still works with jax 0.2.18 and 0.2.20.

Additionally, since a number of dependencies in the requirements file do not pin versions, I rolled all of them back to the latest versions as of January 2022 and used poetry to resolve conflicts. Here is the resulting pyproject.toml:

[tool.poetry.dependencies]
python = "^3.8"
numpy = ">=1.19.5,<1.20.0"
tqdm = ">=4.45.0,<4.46.0"
wandb = "^0.13.7"
einops = ">=0.3.0,<0.4.0"
requests = ">=2.25.1,<2.26.0"
fabric = ">=2.6.0,<2.7.0"
optax = "0.0.9"
dm-haiku = "0.0.5"
ray = {version = "1.4.1", extras = ["default"]}
jax = "0.2.18"
cloudpickle = ">=1.3.0,<1.4.0"
tensorflow-cpu = ">=2.6.0,<2.7.0"
google-cloud-storage = ">=1.36.2,<1.37.0"
transformers = ">=4.16.2,<4.17.0"
smart-open = {version = ">=5.2.1,<5.3.0", extras = ["gcs"]}
ftfy = ">=6.1,<7.0"
lm-dataformat = "^0.0.20"
pathy = "^0.10.1"
func-timeout = "^4.3.5"
chex = "0.0.5"

After installing all of this with poetry, install jax[tpu] with pip, so that it gets the right libtpu nightly build:

pip install "jax[tpu]==0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
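If you're letting poetry manage the virtualenv, the TPU wheel has to land in that same environment; a minimal sketch, assuming the pyproject.toml above sits in the project root:

# Resolve and install the pinned dependencies from pyproject.toml
poetry install
# Swap the plain jax wheel for the TPU build inside poetry's virtualenv
poetry run pip install "jax[tpu]==0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html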

When starting training, you may also find that it gets stuck at validation. As @versae suggested in issue #218, it helps to change the TPU runtime version to an alpha build, something like:

gcloud alpha compute tpus tpu-vm create gptj --accelerator-type v3-8 --version v2-alpha
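For reference, the full command also needs a zone; europe-west4-a below is just an example, use whichever zone holds your v3-8 quota:

gcloud alpha compute tpus tpu-vm create gptj \
  --zone europe-west4-a \
  --accelerator-type v3-8 \
  --version v2-alpha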

With Colab Pro, the default TPU lib (and JAX) is now at 0.3.25. I jumped through these hoops as well and have run with

!pip install mesh-transformer-jax/ jax==0.3.15 tensorflow==2.8.2 chex==0.1.4 jaxlib==0.3.15

Your mileage may vary.

Johnny

@rinapch Worked perfectly for me. Thank you very much.

That said, it worked perfectly for fine-tuning but not for inference on Colab (it caused an optax error).
In order to set up the model for inference, I needed to revert the requirements to the following (install command shown after the list):

numpy~=1.19.5
tqdm~=4.45.0
wandb>=0.11.2
einops~=0.3.0
requests~=2.25.1
fabric~=2.6.0
optax==0.0.6
git+https://github.com/deepmind/dm-haiku
git+https://github.com/EleutherAI/lm-evaluation-harness@c406a62047
ray[default]==1.4.1
jax~=0.2.12
Flask~=1.1.2
cloudpickle~=1.3.0
tensorflow-cpu~=2.5.0
google-cloud-storage~=1.36.2
transformers
smart_open[gcs]
func_timeout
ftfy
fastapi
uvicorn
lm_dataformat
pathy
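If it helps, the list above can be saved as requirements.txt (the filename is arbitrary) and installed in one go, assuming a Python 3.8 environment to match the pins:

pip install -r requirements.txt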

and

!pip install chex==0.1.2
!pip install jaxlib==0.1.68
!pip install dm-haiku==0.0.5

Just as a note.

Thank you so much for this post; it helped me resolve all of my dependency issues. I had never worked with poetry before, but I was able to get a model training in a conda environment using just install commands.

If anybody is interested, I wrote out the steps I took from scratch, which are currently working based on my test run.

-- First, install conda on the TPU VM

mkdir conda_install
cd conda_install
sudo apt-get update
sudo apt-get install wget
wget https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh
bash Anaconda3-2022.10-Linux-x86_64.sh

-- Update PATH to include conda

export PATH=~/anaconda3/bin:$PATH
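-- Optional: persist the PATH change for future sessions (running conda init bash does the equivalent)

echo 'export PATH=~/anaconda3/bin:$PATH' >> ~/.bashrc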

-- Create an env with mamba and Python 3.8

conda create -n gpt -c conda-forge mamba python==3.8

-- Close and reopen the terminal, then re-SSH

gcloud compute tpus tpu-vm ssh YOUR_TPU_NAME --zone YOUR_ZONE_NAME

-- Leave base

conda deactivate 

-- Enter env

conda activate gpt

-- Install requirements available through conda first

mamba install -c conda-forge numpy==1.19.5 tqdm==4.45.0 einops==0.3.0 requests==2.25.1 fabric==2.6.0 optax==0.0.9 dm-haiku==0.0.5 jax==0.2.18 cloudpickle==1.3.0 tensorflow-cpu==2.6.0 google-cloud-storage==1.36.2 transformers==4.16.2 smart_open==5.2.1 ftfy==6.1.1 pathy==0.10.1 func_timeout==4.3.5

-- Install remaining requirements not available through conda with pip

pip install ray[default]==1.4.1 wandb==0.13.7 chex==0.0.5 lm-dataformat==0.0.20 typing-extensions==4.2.0 protobuf==3.19.5

-- NOTE: You will see an error about tensorflow 2.6.0 not being compatible with typing-extensions 4.2.0. This is fine; ignore it.

-- Jax 0.2.12 does NOT WORK with TPUs anymore, but we can use 0.2.18 or 0.2.20

pip install "jax[tpu]==0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

-- If you have issues with protobuf (they may originate from the import wandb call), run this

python3 -m pip uninstall protobuf
python3 -m pip install protobuf==3.19.5

-- Finally, you can run this to fine-tune your model

cd ./mesh-transformer-jax/
python3 device_train.py --config=./configs/YOUR_CONFIG_NAME.json --tune-model-path=gs://YOUR_BUCKET_NAME/step_383500/
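-- Optional: confirm the pretrained weights are actually at the path the flag points to (step_383500 is the published GPT-J-6B checkpoint; YOUR_BUCKET_NAME is your own bucket)

gsutil ls gs://YOUR_BUCKET_NAME/step_383500/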

@JohnnyOpcode How did you run inference with JAX 0.3.15? I thought it only worked with 0.2.12.

I was using Colab Pro (paid) and experimented with different versions of the libraries via pip. The key takeaway is compatibility with the TPUv2 ASIC. I'll try to find some time to go through those motions again and come up with a newer working requirements.txt for everybody.

Python sucks btw. Just like JS and TS. Too many brittle dependencies, but it does create lots of BS positions and salaries.