google / trax

Trax — Deep Learning with Clear Code and Speed

AttributeError: module 'jax.ops' has no attribute 'index_add'

cmosguy opened this issue

Description

I am trying to run these basic imports in my code:

import numpy as np              # regular ol' numpy
from trax import layers as tl   # core building block
from trax import shapes         # data signatures: dimensionality and type
from trax import fastmath       # uses jax, offers numpy on steroids

The import fails on these basic statements (error log below). What am I doing wrong? Should I be pinning a different version of one of these packages?

Environment information

OS: CentOS

$ lsb_release
LSB Version: :core-4.1-amd64:core-4.1-ia32:core-4.1-noarch:cxx-4.1-amd64:cxx-4.1-ia32:cxx-4.1-noarch:desktop-4.1-amd64:desktop-4.1-ia32:desktop-4.1-noarch:languages-4.1-amd64:languages-4.1-noarch:printing-4.1-amd64:printing-4.1-noarch

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0

$ pip freeze | grep jax
jax==0.4.4
jaxlib==0.4.4

$ python -V
Python 3.9.16
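A quick way to confirm a version mismatch before importing trax is to print the installed versions and check whether jax.ops still exposes index_add. This is a minimal sketch, not part of Trax or JAX; the helper names `installed_version` and `check_env` are hypothetical, and it relies only on the standard library's `importlib.metadata` (Python 3.8+).

```python
import importlib
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return "not installed"


def check_env() -> dict:
    """Hypothetical helper: collect package versions and probe jax.ops.index_add."""
    report = {name: installed_version(name) for name in ("trax", "jax", "jaxlib")}
    try:
        jax_ops = importlib.import_module("jax.ops")
        # False here means the installed jax no longer provides index_add.
        report["jax.ops.index_add"] = hasattr(jax_ops, "index_add")
    except Exception:
        # jax is missing, or the jax/jaxlib install itself fails to import.
        report["jax.ops.index_add"] = None
    return report


if __name__ == "__main__":
    for key, value in check_env().items():
        print(f"{key}: {value}")
```

With the environment listed above (jax 0.4.4), the probe would be expected to report `jax.ops.index_add: False`, which is consistent with the traceback.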


### For bugs: reproduction and error logs

# Error logs:

...

      1 # coding=utf-8
      2 # Copyright 2021 The Trax Authors.
      3 #
   (...)
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     16 """Trax top level import."""
---> 18 from trax import data
     19 from trax import fastmath
     20 from trax import layers

File ./ds_work/miniconda3/envs/coursera-nlp/lib/python3.9/site-packages/trax/data/__init__.py:36, in <module>
     16 """Functions and classes for obtaining and preprocesing data.
     17 
     18 The ``trax.data`` module presents a flattened (no subpackages) public API.
   (...)
...
    217     'vjp': jax.vjp,
    218     'vmap': jax.vmap,
    219 }

AttributeError: module 'jax.ops' has no attribute 'index_add'

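Downgrade jax to 0.2.21. jax.ops.index_add was deprecated in jax 0.2.22 and has since been removed, so jax 0.4.4 no longer provides it, while trax 1.3.9 still references it when it is imported. See the changelog entry for jax 0.2.22:
https://gitee.com/mirrors/JAX/blob/main/CHANGELOG.md#jax-0222-oct-12-2021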
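For your own code that called `jax.ops.index_add(x, idx, y)` directly, the replacement JAX recommends is the `.at` property, `x.at[idx].add(y)`. Both forms return a new array and leave `x` unchanged. The sketch below illustrates those out-of-place, accumulate-on-duplicates semantics with plain NumPy so it runs without JAX; `index_add_reference` is a hypothetical name, not a Trax or JAX API. This does not patch trax 1.3.9 itself, so pinning jax is still needed to import it.

```python
import numpy as np


def index_add_reference(x, idx, y):
    """Hypothetical NumPy reference for what jax.ops.index_add(x, idx, y) computed.

    In current JAX the equivalent call is x.at[idx].add(y).
    """
    out = np.array(x, copy=True)  # out-of-place, like JAX: x itself is not modified
    np.add.at(out, idx, y)        # unbuffered add: repeated indices accumulate
    return out


x = np.zeros(5)
result = index_add_reference(x, np.array([1, 3, 3]), 1.0)
print(result)  # [0. 1. 0. 2. 0.]
print(x)       # [0. 0. 0. 0. 0.]  (original unchanged)
```

In JAX 0.2.22 and later, `jnp.zeros(5).at[jnp.array([1, 3, 3])].add(1.0)` should produce the same values. NumPy's `np.add.at` is used instead of `out[idx] += y` because the latter applies a repeated index only once.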