kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

How to infer with GPT-J on TPU_driver0.2 or nightly?

mosmos6 opened this issue

Following issue #252 and this, the GPT-J model has practically become unavailable on Colab TPU, since TPU_driver0.1 is gone. However, by default, GPT-J crashes on the other drivers, including 0.2 and nightly.
Is there any way to use the 0.2 or nightly driver? Otherwise, GPT-J inference on TPU is effectively dead.

I resolved this on my own, so I'm sharing the modified mesh-transformer-jax with everyone.

Background:
In early March 2023, Google removed TPU_driver0.1 from Colab. The original GPT-J code strictly requires JAX 0.2.12, and TPU_driver0.2 needs a newer JAX, so inference on Colab stopped working entirely.
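
For context, the original Colab demo selected the TPU driver with a setup cell roughly like the following (a sketch of the old notebook; the exact cell may differ slightly). This is the request that now fails, since tpu_driver0.1 no longer exists on Colab:

```python
import os
import requests
from jax.config import config

# Ask the Colab TPU runtime to load a specific driver version. The original
# demo requested tpu_driver0.1, which Google removed in early March 2023.
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
requests.post(f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1')

# Point JAX at the TPU driver backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
```

Simply requesting tpu_driver0.2 or nightly in that cell is not enough on its own, because the pinned jax==0.2.12 cannot talk to the newer drivers; the code and the requirements both have to change.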

Takeaway:
I made some modifications to the mesh_transformer folder and the Colab demo, together with updated requirements, so you can now run inference on Colab again. You can continue to use the same (slim) weights as before.

How:
I uploaded the relevant file and folder here. Extract the mesh_transformer folder and the requirements.txt file, replace the originals in your own repo, and use GPT-J inference on TPU_driver0.2.ipynb to run inference.
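
After swapping the files in, a quick sanity check (illustrative, not part of the notebook) is to confirm that the runtime picked up the newer JAX and can see the TPU cores:

```python
# Illustrative post-install check; run it in the Colab notebook after
# installing the updated requirements and connecting to the TPU.
import jax

print(jax.__version__)     # should be newer than the original 0.2.12 pin
print(jax.device_count())  # a Colab TPU runtime normally reports 8 cores
```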

Important notes:

  1. Sorry, you'll need a Colab Pro or Pro+ subscription, because this requires a high-memory TPU runtime.

  2. I have not tested fine-tuning on a TPU VM yet, and it may throw errors during the process. I plan to cover it next month. Until then, you may need to make further modifications to xmap yourself, or downgrade to JAX 0.2.18 or 0.2.20 (see the sketch after this list).

  3. You can also run GPT-J inference with device_serve.py on a TPU VM, but you can't use the original file. If anyone wants it, please post a request in the issues.
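
Regarding note 2, the fine-tuning breakage comes from how the codebase drives xmap. Below is a minimal sketch of the old-style (JAX 0.2.x) pattern that mesh-transformer-jax relies on; `sum_across_shards` is an illustrative stand-in rather than a function from this repo, and newer JAX releases renamed or removed parts of this API, which is why xmap needs further modification (or JAX needs downgrading) for fine-tuning:

```python
# A minimal sketch of the xmap pattern used throughout mesh-transformer-jax,
# written against the old jax 0.2.x API. `sum_across_shards` is illustrative.
import jax
import numpy as np
from jax.experimental.maps import xmap, mesh  # `mesh` changed in newer JAX

def sum_across_shards(x):
    # Collective across the model-parallel axis, as the transformer layers do.
    return jax.lax.psum(x, "shard")

fn = xmap(sum_across_shards,
          in_axes=["shard", ...],
          out_axes=["shard", ...],
          axis_resources={"shard": "mp"})

# Lay all local devices out along the model-parallel ("mp") mesh axis.
devices = np.array(jax.devices()).reshape((1, jax.device_count()))

with mesh(devices, ("dp", "mp")):
    print(fn(np.ones(jax.device_count(), dtype=np.float32)))
```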
