kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

AttributeError: module 'jax.random' has no attribute 'KeyArray' while fine tuning.

samyakai opened this issue · comments

I am following your repo to fine-tune GPT-J on a TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get the error "AttributeError: module 'jax.random' has no attribute 'KeyArray'". These are my specs:

OS: Ubuntu 20.04
jax version = 0.2.12
TPU : V3-8
Zone : us-central1-b

The error is raised at line 7 of device_train.py, where optax is imported ("import optax").

This is the stack trace:

WARNING: Logging before InitGoogle() is written to STDERR
I0420 11:47:44.856002 10240 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
  File "device_train.py", line 7, in <module>
    import optax
  File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
    from optax._src.alias import adabelief
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/alias.py", line 21, in <module>
    from optax._src import base
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/base.py", line 18, in <module>
    import chex
  File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 36, in <module>
    PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'

Any help is appreciated!

I've just encountered exactly the same error and I was about to open an issue about this.

I am facing the same error. Kindly solve this!

I am hitting the same error on "import optax".

Exact same issue here!

same issue!

Chex 0.1.3 doesn't support JAX 0.2.12. You need to downgrade to Chex 0.1.2:

pip3 install chex==0.1.2

@vfbd That worked for me for inference, but apparently not for fine-tuning.

@mosmos6 @vfbd Now I get a different error: "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'".
@mosmos6 How did you get it to work? Are you training on a TPU v3-8?

@samyakai I just noticed that you hit this error while fine-tuning. I hit the same error, but during inference. Sorry for the confusion; I have edited my previous comment. The issue is still unresolved for fine-tuning.

As @vfbd suggested, I downgraded chex to 0.1.2, but then I hit "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". To get past that, https://github.com/google/brax/issues/187 suggests upgrading to the latest version. If I do that, I am back to "AttributeError: module 'jax.random' has no attribute 'KeyArray'".
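
For anyone trying to work out which of the two failures applies to their environment, here is a minimal sketch that probes for the two attributes named in the tracebacks in this thread without importing optax or chex. The helper name has_attr_path is hypothetical (it is not part of JAX or chex), and the probe list only covers the two attributes reported here.

```python
import importlib


def has_attr_path(module_name, attr_path):
    """Return True if the module imports and the dotted attribute path resolves.

    Hypothetical helper for illustration; not part of JAX, jaxlib, or chex.
    """
    try:
        obj = importlib.import_module(module_name)
    except Exception:
        # A missing or broken install counts as "attribute not available".
        return False
    for part in attr_path.split("."):
        if not hasattr(obj, part):
            return False
        obj = getattr(obj, part)
    return True


# The two attributes that chex/_src/pytypes.py fails on in the tracebacks above.
PROBES = [
    ("jax.random", "KeyArray"),
    ("jaxlib.xla_extension", "CpuDevice"),
]

if __name__ == "__main__":
    for module_name, attr in PROBES:
        status = "present" if has_attr_path(module_name, attr) else "missing"
        print(f"{module_name}.{attr}: {status}")
```

If either line prints "missing", the installed chex is presumably looking for an attribute that the installed jax or jaxlib does not provide, which matches the two errors reported in this thread.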

I am following your repo to fine-tune GPT-J on a TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I created, I get the error "AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'". These are my specs:

OS: Ubuntu 20.04
jax version = 0.2.12
chex version = 0.1.2
TPU : V3-8
Zone : us-central1-b

The error is raised at line 7 of device_train.py, where optax is imported ("import optax").

This is the stack trace:

WARNING: Logging before InitGoogle() is written to STDERR
I0421 10:06:19.047791 8679 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Traceback (most recent call last):
  File "device_train.py", line 7, in <module>
    import optax
  File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
    from optax import experimental
  File "/usr/local/lib/python3.8/dist-packages/optax/experimental/__init__.py", line 20, in <module>
    from optax._src.experimental.complex_valued import split_real_and_imaginary
  File "/usr/local/lib/python3.8/dist-packages/optax/_src/experimental/complex_valued.py", line 32, in <module>
    import chex
  File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 40, in <module>
    CpuDevice = jax.lib.xla_extension.CpuDevice
AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'

Please help us resolve this as soon as possible.
Thank you.

Which version of jaxlib (not jax) do you have? Maybe try again with jaxlib==0.1.68

@vfbd These library versions resolve the error:
jax==0.2.16
jaxlib==0.1.68
optax==0.1.2
chex==0.1.2
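
To confirm that an environment actually matches these pins before running device_train.py, something like the following sketch may help. It only reads installed package metadata and does not import jax, optax, or chex, so it should run even when those imports are broken. The names PINS, installed_version, and find_mismatches are hypothetical, chosen for illustration, and the code assumes Python 3.8 or newer (for importlib.metadata).

```python
from importlib import metadata

# Versions reported in this thread as resolving both AttributeErrors.
PINS = {
    "jax": "0.2.16",
    "jaxlib": "0.1.68",
    "optax": "0.1.2",
    "chex": "0.1.2",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Return {package: (wanted, found)} for every package not at its pin."""
    mismatches = {}
    for name, wanted in pins.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINS).items():
        print(f"{name}: want {wanted}, found {found or 'not installed'}")
```

No output means every package is at the pinned version. The lookup parameter exists only so the comparison can be exercised without touching the real environment.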

I just paid for Colab Pro to play around with this and ran into the same issues described here. I added !pip install lines for the library versions mentioned above and then got this error:

`---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
in
4 from jax.experimental import maps
5 import numpy as np
----> 6 import optax
7 import transformers
8

6 frames
/usr/local/lib/python3.8/dist-packages/jax/_src/api.py in
42 from . import dtypes
43 from ..core import eval_jaxpr
---> 44 from ..api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
45 flatten_fun_nokwargs2, argnums_partial,
46 argnums_partial_except, flatten_axes, donation_vector,

ImportError: cannot import name '_ensure_str_tuple' from 'jax.api_util' (/usr/local/lib/python3.8/dist-packages/jax/api_util.py)

---------------------------------------------------------------------------`

I appreciate that this is alpha-stage software, so I'll go play with GPT-3 in the meantime. Thank you for your work.


from jax_md import rigid_body
File "C:\Users....\env\Lib\site-packages\jax_md\rigid_body.py", line 76, in
KeyArray = random.KeyArray

module 'jax.random' has no attribute 'KeyArray'

I get this error when trying to run "import jax_md".