kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

TypeError: Cannot subclass <class 'typing._SpecialForm'> while fine tuning

samyakai opened this issue

I am trying to fine-tune GPT-J on custom data using a TPU. When I run "device_train.py" with the suggested command, "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/", I get this error:

Traceback (most recent call last):
  File "device_train.py", line 13, in <module>
    from mesh_transformer import util
  File "/home/shreyjain/mesh-transformer-jax/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
  File "/usr/lib/python3.8/typing.py", line 317, in __new__
    raise TypeError(f"Cannot subclass {cls!r}")
TypeError: Cannot subclass <class 'typing._SpecialForm'>

OS: Ubuntu 20.04
TPU: v3-8
Python version: 3.8 and 3.7 both give the error

I have no idea what this error means. Any help would be appreciated!
Thank you.
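A plausible reading of this traceback, offered as an assumption rather than a confirmed diagnosis: line 36 of util.py subclasses optax's `OptState`, which only works while `OptState` is a real class. If the installed optax instead defines `OptState` as a typing alias (for example a `Union` of array types), Python refuses to build the subclass and raises a `TypeError`. The sketch below reproduces that difference using two hypothetical stand-ins, `OptStateAsClass` and `OptStateAsAlias`; neither is copied from optax.

```python
"""Sketch (assumption): why `class ClipByGlobalNormState(OptState)` can fail.

OptStateAsClass and OptStateAsAlias are hypothetical stand-ins for two ways a
library could define OptState; they are not copied from optax.
"""
from typing import Any, Iterable, Mapping, NamedTuple, Union

# Stand-in for a class-style OptState: subclassing it builds a new namedtuple type.
OptStateAsClass = NamedTuple

# Stand-in for an alias-style OptState: a Union is a typing construct, not a class.
OptStateAsAlias = Union[int, Iterable[Any], Mapping[Any, Any]]


def try_subclass(base: Any) -> str:
    """Mimic util.py line 36 and report whether the class definition succeeds."""
    try:
        class ClipByGlobalNormState(base):  # same shape as the failing line
            pass
    except TypeError as err:
        return f"TypeError: {err}"
    return "ok"


if __name__ == "__main__":
    print("class-style base:", try_subclass(OptStateAsClass))
    print("alias-style base:", try_subclass(OptStateAsAlias))
```

On Python 3.8 the alias-style case should fail with the same `Cannot subclass <class 'typing._SpecialForm'>` message seen in the traceback; later Python versions word the error differently. This is consistent with the optax downgrade reported later in the thread making the error go away, though the sketch does not prove it.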

Getting the same issue and not able to solve it.
Please help us.
Thank you.

Maybe this works.

pip install dm-haiku==0.0.5
and put optax back to the default version.

@mosmos6 Nope, it doesn't work.

#202 (comment)
Please follow this solution! It works.

@anon-mouse-1 Which v2 version of the TPU should I use? There are two options: the TPU VM architecture and the TPU Node architecture.

After the 5th error I just gave up on this notebook.

@Tylersuard Yes. Also, no one is providing a solution to the errors, which is a shame, as I really want to train on a TPU rather than a GPU.

Downgrading optax worked for me to get rid of this error.

pip install optax==0.0.9

> #202 (comment) Please follow this solution! It works.

In addition to this,
pip install chex==0.1.2
pip install jaxlib==0.1.74
pip install dm-haiku==0.0.5

and it worked for me.
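Taken together, the fixes in this thread amount to pinning four packages. As a convenience, here is a small, hypothetical pre-flight helper (not part of mesh-transformer-jax) that compares installed versions against those pins before launching device_train.py. It assumes Python 3.8+ for `importlib.metadata`, so it will not run as-is in the 3.7 environment mentioned above.

```python
"""Hypothetical pre-flight check: compare installed versions with the pins
reported as working in this thread. Requires Python 3.8+ (importlib.metadata)."""
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, List

# Pins collected from the comments above (optax downgrade plus the last comment).
PINS: Dict[str, str] = {
    "optax": "0.0.9",
    "chex": "0.1.2",
    "jaxlib": "0.1.74",
    "dm-haiku": "0.0.5",
}


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> List[str]:
    """Return one message per package whose installed version differs from its pin."""
    problems: List[str] = []
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            problems.append(f"{name}: not installed (want {wanted})")
            continue
        if installed != wanted:
            problems.append(f"{name}: {installed} installed (want {wanted})")
    return problems


if __name__ == "__main__":
    mismatches = find_mismatches(PINS)
    if mismatches:
        print("Version mismatches:")
        for line in mismatches:
            print("  " + line)
        # Build a single pip command that applies every pin at once.
        pin_args = " ".join(f"{n}=={v}" for n, v in PINS.items())
        print(f"Suggested fix: pip install {pin_args}")
    else:
        print("All pinned versions match.")
```

The `get_version` parameter exists only so the check can be exercised without touching the real environment. These pins are what commenters reported working, not an officially supported combination.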