AttributeError: module 'jax.ops' has no attribute 'index_add'
cmosguy opened this issue
### Description
I am trying to do something basic in my code:
```python
import numpy as np  # regular ol' numpy
from trax import layers as tl  # core building block
from trax import shapes  # data signatures: dimensionality and type
from trax import fastmath  # uses jax, offers numpy on steroids
```
Even these basic imports fail with the error below. What am I doing wrong? Should I be pinning different package versions?
### Environment information
OS: CentOS
```
$ lsb_release
LSB Version: :core-4.1-amd64:core-4.1-ia32:core-4.1-noarch:cxx-4.1-amd64:cxx-4.1-ia32:cxx-4.1-noarch:desktop-4.1-amd64:desktop-4.1-ia32:desktop-4.1-noarch:languages-4.1-amd64:languages-4.1-noarch:printing-4.1-amd64:printing-4.1-noarch
$ pip freeze | grep trax
trax==1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0
$ pip freeze | grep jax
jax==0.4.4
jaxlib==0.4.4
$ python -V
Python 3.9.16
```
### For bugs: reproduction and error logs
```
# Error logs:
...
      1 # coding=utf-8
      2 # Copyright 2021 The Trax Authors.
      3 #
   (...)
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     16 """Trax top level import."""
---> 18 from trax import data
     19 from trax import fastmath
     20 from trax import layers

File ./ds_work/miniconda3/envs/coursera-nlp/lib/python3.9/site-packages/trax/data/__init__.py:36, in <module>
     16 """Functions and classes for obtaining and preprocesing data.
     17
     18 The ``trax.data`` module presents a flattened (no subpackages) public API.
   (...)
...
    217     'vjp': jax.vjp,
    218     'vmap': jax.vmap,
    219 }

AttributeError: module 'jax.ops' has no attribute 'index_add'
```
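A quick way to confirm the cause is to check whether the installed jax still exposes the old helper. This is a minimal sketch; the function name `has_legacy_index_add` is hypothetical and not part of jax or trax. With the jax 0.4.4 listed above it should report `False`, which matches the `AttributeError`.

```python
import jax
import jax.ops


def has_legacy_index_add() -> bool:
    # Hypothetical helper: True only if this jax still ships the removed
    # jax.ops.index_add, which the traceback shows trax touching at import time.
    return hasattr(jax.ops, "index_add")


print("jax", jax.__version__, "| jax.ops.index_add available:", has_legacy_index_add())
```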
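Downgrade jax to 0.2.21. `jax.ops.index_add` was deprecated in jax 0.2.22 and has since been removed, which is why jax 0.4.4 raises this error when trax is imported. See the changelog:
https://gitee.com/mirrors/JAX/blob/main/CHANGELOG.md#jax-0222-oct-12-2021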
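For code you control, the replacement JAX recommends for the removed `jax.ops.index_*` functions is the `.at` property on arrays. Below is a minimal sketch with made-up values; as with `index_add`, repeated indices accumulate. Note that this does not patch trax 1.3.9 itself, which still references the old name, so for trax the jax version pin above remains the workaround.

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([0, 2, 2])  # index 2 appears twice
y = jnp.array([1.0, 2.0, 3.0])

# Before (removed from jax): out = jax.ops.index_add(x, idx, y)
# After: .at[...].add returns a new array; x itself is left unchanged
out = x.at[idx].add(y)
print(out)  # [1. 0. 5. 0. 0.]
```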