kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

TypeError: __init__() takes 2 positional arguments but 4 were given

Traceback (most recent call last):
  File "/Users/macos/Desktop/AI/serve.py", line 9, in <module>
    from mesh_transformer.checkpoint import read_ckpt
  File "/Users/macos/opt/anaconda3/lib/python3.9/site-packages/mesh_transformer/checkpoint.py", line 14, in <module>
    from mesh_transformer.util import head_print
  File "/Users/macos/opt/anaconda3/lib/python3.9/site-packages/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
TypeError: __init__() takes 2 positional arguments but 4 were given

Your installed version of JAX is too new for this package. Try installing the dependency versions listed in #246.
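Before changing anything, it can help to see which versions are actually installed so they can be compared against the list in #246. The snippet below is a minimal sketch using only the standard library. The helper name `installed_versions` and the package list are assumptions for illustration: the thread itself names only jax, and the other packages are ones that commonly sit alongside it.

```python
from importlib import metadata


def installed_versions(packages):
    """Return a dict mapping each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed package list; adjust it to match whatever #246 actually pins.
    for name, version in installed_versions(
        ["jax", "jaxlib", "optax", "dm-haiku", "chex"]
    ).items():
        print(f"{name}: {version or 'not installed'}")
```

Run it with the same interpreter that produced the traceback (the anaconda3 Python 3.9 environment above), since another environment may have different versions installed.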