kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'

umm-maybe opened this issue

Hello, I have followed the (very much appreciated) howto_finetune.md guide, and when I ran the magic python device_train.py command I got the error in the title. The only Google search result that seems to mention something similar is this: https://bytemeta.vip/repo/deepmind/alphafold/issues/515

The answer to that question seems to imply it has to do with a version incompatibility between jax and jaxlib, but the solution they link to doesn't work here. Any tips or advice for working around this would be greatly appreciated!

Off the top of my head: pip install jax==0.2.12 jaxlib==0.1.67
I can't try it right now, but that version combination should work on a TPU VM.
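A small sketch for checking whether the installed jax and jaxlib actually match the pinned pair above before launching device_train.py. It assumes Python 3.8+ (on Python 3.7 it falls back to the importlib_metadata backport, which must be installed separately). SUGGESTED_PINS, installed_version and pin_mismatches are hypothetical names for illustration, not part of mesh-transformer-jax.

```python
try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python 3.7: requires the importlib_metadata backport
    from importlib_metadata import PackageNotFoundError, version

# Pins suggested in the comment above (assumption: adjust if another pair works for you).
SUGGESTED_PINS = {"jax": "0.2.12", "jaxlib": "0.1.67"}


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def pin_mismatches(pins, lookup=installed_version):
    """Return {package: (expected, found)} for every package whose installed
    version differs from its pin; `found` is None when the package is missing."""
    mismatches = {}
    for package, expected in pins.items():
        found = lookup(package)
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    problems = pin_mismatches(SUGGESTED_PINS)
    if not problems:
        print("jax/jaxlib match the suggested pins")
    for package, (expected, found) in problems.items():
        print(f"{package}: expected {expected}, found {found}")
```

The lookup function is injectable so the comparison logic can be exercised without jax installed.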

Edit: I think it also depends on which Python version you're running (3.7 on TPU v2 and Colab, 3.8 on v3) and on the TPU version / accelerator type. I think I've seen jaxlib==0.1.68 in v2 setups, so that's also worth a shot.

I'm also using a TPU v2 setup and ran into this problem. I used the JAX TPU install instructions from their README and it worked for me.

Now I'm also getting "AttributeError: module 'jax' has no attribute 'version'", or alternatively:
AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'. I've tried a couple of different Colab notebooks, and none of them work.

I fixed it by running this right after the install dependencies section:

!pip install jaxlib==0.1.67

Then restart the runtime if Colab asks you to.

It feels very fragile, though, and I don't know why.
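One likely reason the restart step matters: pip replaces the package files on disk, but modules already imported into the running Python process keep the old code. The hypothetical helper below (not part of Colab or JAX) is a minimal sketch that reports which of jax/jaxlib were already imported, i.e. which ones would need a runtime restart before a freshly installed version takes effect.

```python
import sys


def needs_restart(packages=("jax", "jaxlib"), modules=None):
    """Return the packages from `packages` that are already imported.

    Reinstalling these with pip does not reload the code that is already
    in memory, so the runtime has to be restarted to pick up the new version.
    `modules` defaults to sys.modules; it can be overridden for testing.
    """
    loaded = sys.modules if modules is None else modules
    return [name for name in packages if name in loaded]


if __name__ == "__main__":
    stale = needs_restart()
    if stale:
        print("Restart the runtime after reinstalling:", ", ".join(stale))
    else:
        print("jax/jaxlib not imported yet; no restart needed after pip install")
```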