kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

Typo in 'to_hf_weights.py'

AmoArt opened this issue

On line 461, 'with maps.mesh(devices, ("dp", "mp")):' should be 'with maps.Mesh(devices, ("dp", "mp")):'. Otherwise the script fails with an AttributeError saying that module 'jax.experimental.maps' has no attribute 'mesh'.

commented

jax.experimental.maps does have "mesh" as long as you have jax<=0.3.7:

❯ python3.9
Python 3.9.13 (main, Jun  8 2022, 09:45:57) 
[GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jaxlib
>>> jax.__version__
'0.2.12'
>>> jaxlib.__version__
'0.1.68'
>>> jax.experimental.maps.mesh
<function mesh at 0x7f6346581940>
>>> from jax.experimental import maps
>>> maps.mesh
<function mesh at 0x7f6346581940>
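A minimal sketch of a version-tolerant alternative, assuming the installed JAX still provides `jax.experimental.maps` (newer releases may not ship it at all). The hypothetical helper `mesh_context` calls the lowercase `maps.mesh` when it exists, which per the comment above covers jax<=0.3.7, and falls back to `maps.Mesh` otherwise. The `maps_module` parameter is likewise an illustration-only addition so the fallback logic can be exercised without JAX installed.

```python
import importlib


def mesh_context(devices, axis_names, maps_module=None):
    """Return a context manager that activates a device mesh.

    Hypothetical helper (not part of mesh-transformer-jax). Assumes the
    installed JAX still ships jax.experimental.maps.
    """
    if maps_module is None:
        # Imported lazily so the selection logic can be tested without JAX.
        maps_module = importlib.import_module("jax.experimental.maps")
    # Older JAX: the lowercase function, as used on line 461 today.
    if hasattr(maps_module, "mesh"):
        return maps_module.mesh(devices, axis_names)
    # Newer JAX: only the Mesh class remains, as suggested in the issue.
    return maps_module.Mesh(devices, axis_names)
```

With this helper, line 461 of to_hf_weights.py could read 'with mesh_context(devices, ("dp", "mp")):'. The other option is to leave the script unchanged and pin jax<=0.3.7, as the comment implies.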