kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

`TypeError: Cannot subclass <class 'typing._SpecialForm'>` in `slim_model.py`

danyaljj opened this issue

Any ideas why I am getting this error?

```
$ python slim_model.py
Traceback (most recent call last):
  File "slim_model.py", line 9, in <module>
    from mesh_transformer import util
  File "/home/danielk/mesh-transformer-jax/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
  File "/home/danielk/anaconda3/envs/jax_py38/lib/python3.8/typing.py", line 317, in __new__
    raise TypeError(f"Cannot subclass {cls!r}")
TypeError: Cannot subclass <class 'typing._SpecialForm'>
```
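
The last frames show that the failure happens at class-definition time, when `ClipByGlobalNormState` is declared with `OptState` as its base. The sketch below reproduces that mechanism without optax. It assumes, which this thread does not confirm, that `OptState` in the installed optax 0.1.1 is a `typing.Union`-style alias rather than a `NamedTuple`; `NamedTupleBase`, `UnionAliasBase`, `can_subclass` and `_State` are hypothetical names used only for illustration.

```python
from typing import NamedTuple, Union

# Hypothetical stand-ins for two possible definitions of an optimizer-state base:
# a plain NamedTuple (assumed older behaviour) and a typing.Union alias
# (assumed newer behaviour). Neither is imported from optax.
NamedTupleBase = NamedTuple
UnionAliasBase = Union[dict, list, tuple]


def can_subclass(base) -> bool:
    """Return True if `base` can be used as a class base, False on TypeError."""
    try:
        # Mirrors `class ClipByGlobalNormState(OptState):` from the traceback.
        class _State(base):
            """Stateless transformation state."""
        return True
    except TypeError:
        # A Union alias is not a class, so Python refuses to subclass it.
        return False


print(can_subclass(NamedTupleBase))  # True
print(can_subclass(UnionAliasBase))  # False
```

The Union case raises `TypeError` on Python 3.7 as well, so if the assumption above holds, the installed package versions would matter more than the interpreter version.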

Here is my environment:

```
$ pip list
Package           Version
----------------- ---------
absl-py           1.0.0
certifi           2021.10.8
chex              0.1.1
dm-tree           0.1.6
flatbuffers       2.0
jax               0.2.12
jaxlib            0.1.68
numpy             1.22.3
opt-einsum        3.3.0
optax             0.1.1
pip               21.2.4
scipy             1.8.0
setuptools        58.0.4
six               1.16.0
toolz             0.11.2
typing_extensions 4.1.1
wheel             0.37.1
```

running on:

```
$ python --version
Python 3.8.12
```

Update: Python 3.7 seems to be working fine.

That said, I am not sure that downgrading Python is the right approach; parts of the code appear to be written for Python >= 3.8.

@danyaljj, were you able to fine-tune successfully on Python 3.7? I am trying to fine-tune on Python 3.8 and get the same error. I also tried Python 3.7, but I still get the same error.