google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

ImportError: cannot import name 'index' from 'jax.ops'

Tusay opened this issue

I'm trying to work through the tutorial here: http://secondearths.sakura.ne.jp/exojax/tutorials/optimize_spectrum_JAXopt.html

When I reach the following block of code, I get an error. I'm using a GPU on Google Colab and have confirmed that it is running.

from exojax.spec.lpf import xsmatrix
from exojax.spec.exomol import gamma_exomol
from exojax.spec.hitran import SijT, doppler_sigma, gamma_natural, gamma_hitran
from exojax.spec.hitrancia import read_cia, logacia
from exojax.spec.rtransfer import rtrun, dtauM, dtauCIA, nugrid
from exojax.spec import planck, response
from exojax.spec.lpf import xsvector
from exojax.spec import molinfo
from exojax.utils.constants import RJ, pc, Rs, c

pip show jax outputs:
Name: jax
Version: 0.3.4
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.7/dist-packages
Requires: typing-extensions, scipy, absl-py, opt-einsum, numpy
Required-by: numpyro, jaxopt, exojax

Full error messages/tracebacks:


ImportError Traceback (most recent call last)
in <module>()
4 from exojax.spec.hitrancia import read_cia, logacia
5 from exojax.spec.rtransfer import rtrun, dtauM, dtauCIA, nugrid
----> 6 from exojax.spec import planck, response
7 from exojax.spec.lpf import xsvector
8 from exojax.spec import molinfo

/usr/local/lib/python3.7/dist-packages/exojax/spec/__init__.py in <module>()
15 )
16
---> 17 from exojax.spec.autospec import (
18 AutoXS,
19 AutoRT,

/usr/local/lib/python3.7/dist-packages/exojax/spec/autospec.py in <module>()
1 """Automatic Opacity and Spectrum Generator."""
2 import time
----> 3 from exojax.spec import defmol, defcia, moldb, contdb, planck, molinfo, lpf, dit, modit, initspec, response
4 from exojax.spec.opacity import xsection
5 from exojax.spec.hitran import SijT, doppler_sigma, gamma_natural, gamma_hitran, normalized_doppler_sigma

/usr/local/lib/python3.7/dist-packages/exojax/spec/dit.py in <module>()
10 from jax.lax import scan
11 from exojax.spec.ditkernel import fold_voigt_kernel
---> 12 from jax.ops import index as joi
13 from exojax.spec.atomll import padding_2Darray_for_each_atom
14 from exojax.spec.rtransfer import dtauM

ImportError: cannot import name 'index' from 'jax.ops' (/usr/local/lib/python3.7/dist-packages/jax/ops/__init__.py)

I'm not sure how to resolve this issue.

Thanks for the report! jax.ops.index was deprecated in JAX version 0.2.22 and removed in version 0.3.2 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-032-march-16-2022).

Instead of jax.ops.index, we recommend jnp.index_exp (which is essentially identical).
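For a package that has to run on both older and newer JAX releases, one possible approach is a guarded import that falls back to jnp.index_exp. The following is only a sketch: the alias name joi is borrowed from the traceback above, the array x and the slice are made up for illustration, and the use of the .at property in place of the older index-update helpers is an assumption about how the downstream code consumes the index.

```python
import jax.numpy as jnp

# Guarded import: older JAX still ships jax.ops.index, newer releases do not.
try:
    from jax.ops import index as joi  # works on JAX 0.3.1 or older
except ImportError:
    joi = jnp.index_exp  # replacement suggested for JAX 0.3.2 and newer

# Illustrative array; not taken from exojax.
x = jnp.zeros((3, 4))

# Build the index expression once, then reuse it with the .at property.
idx = joi[1, :2]
y = x.at[idx].set(1.0)  # returns a new array; x itself is unchanged
z = x.at[idx].add(2.5)
```

Because JAX arrays are immutable, both updates return new arrays and leave x untouched.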

If you depend on another project that still imports it (here, exojax), you'll have to downgrade to JAX 0.3.1 or older until that package is updated.
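To check whether an installed JAX version falls on the affected side of that boundary, a small version comparison can help. This is a minimal sketch using only the standard library; the helper names parse_version, has_jax_ops_index, and installed_jax_version are hypothetical, and importlib.metadata assumes Python 3.8 or newer (the Colab environment above reports Python 3.7).

```python
import re
from importlib import metadata

# Per the changelog cited above, jax.ops.index is gone from 0.3.2 onward.
REMOVED_IN = (0, 3, 2)


def parse_version(text):
    """Turn a string like '0.3.4' into (0, 3, 4), ignoring suffixes such as 'rc1'."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:  # pad short versions such as '0.3'
        parts.append(0)
    return tuple(parts)


def has_jax_ops_index(version_text):
    """Return True if this JAX version predates the removal of jax.ops.index."""
    return parse_version(version_text) < REMOVED_IN


def installed_jax_version():
    """Return the installed JAX version string, or None if JAX is absent."""
    try:
        return metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None


version = installed_jax_version()
if version is not None:
    print(version, "has jax.ops.index:", has_jax_ops_index(version))
```

With the version reported in this issue, has_jax_ops_index("0.3.4") is False, which matches the ImportError above.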