kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

TPU-V4

wimjan123 opened this issue

How can one use this project to fine-tune using a TPU-v4 instance?
I tried everything, but always get errors.
Most commonly:

UserWarning: cloud_tpu_init failed: KeyError('v4-8')
This a JAX bug; please report an issue at https://github.com/google/jax/issues
_warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
2023-03-05 21:55:43.305762: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-03-05 21:55:43.941977: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-05 21:55:43.942070: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-03-05 21:55:43.942076: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "device_train.py", line 191, in
raise ValueError(msg)
ValueError: each shard needs a separate device, but device count (1) < shard count (4)

Your installed JAX version is wrong for your TPU version (this repo is old). Basically you have to keep trying installations and images (I use the v2-alpha image on a TPU v3-8). Once the command below works, JAX is installed and working on your TPU:

python3 -c "import jax; print(jax.devices())"  # should print TpuDevice
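
If it prints CpuDevice instead of TpuDevice, the usual fix is to reinstall a TPU build of JAX. The standard install command from the JAX documentation is the following; note that this old repo may additionally need pinned versions:

pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html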

Also, the libcudart errors mean you should uninstall tensorflow and install tensorflow-cpu, since there is no GPU on a TPU VM.
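
For example, with plain pip (adjust for your environment):

pip uninstall -y tensorflow
pip install tensorflow-cpu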

I would also recommend going through https://github.com/ayaka14732/tpu-starter; it can help with some of the errors you will face.

I use the v2-alpha-tpu4 image on a TPU v4-8.
The command to check whether JAX is installed returns this:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

According to the following, JAX reports 4 devices instead of 8 on TPU v4 (each v4 chip exposes its two cores as a single device):

Display the number of TPU cores available:
jax.device_count()
The number of TPU cores is displayed. If you are using a v4 TPU, this should be 4. If you are using a v2 or v3 TPU, this should be 8.

(source: https://cloud.google.com/tpu/docs/run-calculation-jax)
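
A quick sanity check on the VM, using the standard JAX APIs jax.devices() and jax.device_count():

import jax

for d in jax.devices():
    print(d.id, d.device_kind)  # four devices on a v4-8, matching the list above
print(jax.device_count())       # 4 on a v4-8; 8 on a v2-8 or v3-8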

Aha, I see. Is there any way to fine-tune GPT-J using 4 TPU cores?

I changed the following from 8 to 4 in the configuration file:

"cores_per_replica": 4

If I do that, I get an "AssertionError: Incompatible checkpoints" error.

I forgot to mention that this was for pre-training from scratch. The compatibility error seems like a valid issue, since it's not clear whether checkpoints saved with 8 shards can be loaded on 4 cores.

Is there any way to convert the checkpoints to, let's say, 4 shards?

No idea; I would guess not, and I didn't try it. I plan to move to TPU v4 as well.
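
Conceptually, resharding would mean loading all 8 shards, concatenating each parameter along its partitioned axis, and re-splitting into 4. A minimal numpy sketch, assuming every parameter is split along axis 0; the real mesh-transformer-jax checkpoint layout is more involved, so treat this as an outline only:

import numpy as np

def reshard(shards, new_count):
    # shards: one list of arrays per old shard, all in the same parameter order
    out = [[] for _ in range(new_count)]
    for pieces in zip(*shards):                # walk parameters in lockstep
        full = np.concatenate(pieces, axis=0)  # undo the old partitioning
        for i, part in enumerate(np.split(full, new_count, axis=0)):
            out[i].append(part)                # axis 0 must divide evenly
    return out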

I'm curious how this attempt turned out. Has anyone succeeded in running GPT-J on TPU v4?