patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/



jax dependency error when jax is not installed

jsternabsci opened this issue.

jaxtyping 0.2.26 introduced a bug: using the @jaxtyped decorator fails when jax is not installed.

Error:

/opt/conda/envs/testenv-20743/lib/python3.11/site-packages/jaxtyping/_decorator.py:192: in jaxtyped
    if _tb_flag and importlib.util.find_spec("jax._src.traceback_util") is not None:
E   ModuleNotFoundError: No module named 'jax'

Breaking change:
v0.2.25...v0.2.26#diff-f792a47fc41c0cf008332f62022e6136a8f6c6e0514c743eedc7df172e519ce5R192
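A minimal Python sketch of the failure mode in the traceback above: `importlib.util.find_spec` imports the parent package of a dotted name, so calling it on `jax._src.traceback_util` raises `ModuleNotFoundError` instead of returning `None` when `jax` is absent. The helper `has_submodule` and the package name `not_installed_pkg_demo` are hypothetical, chosen for illustration, and this is not necessarily how #188 fixed it. Checking each prefix of the dotted name in turn means every parent is known to exist before its child is looked up.

```python
import importlib.util


def has_submodule(name: str) -> bool:
    """Return True if the dotted module `name` can be located.

    Each prefix ("pkg", then "pkg.sub", ...) is checked in order, so a
    missing parent package yields False rather than an exception.
    Note: find_spec still imports each parent package along the way.
    """
    parts = name.split(".")
    for i in range(1, len(parts) + 1):
        if importlib.util.find_spec(".".join(parts[:i])) is None:
            return False
    return True


# Calling find_spec directly on the dotted name raises when the
# top-level package is missing, mirroring the error in the report.
try:
    importlib.util.find_spec("not_installed_pkg_demo._src.traceback_util")
except ModuleNotFoundError as e:
    print("raised:", e)  # raised: No module named 'not_installed_pkg_demo'

print(has_submodule("not_installed_pkg_demo._src.traceback_util"))  # False
print(has_submodule("json.decoder"))  # True
```

Because parent packages are still imported during the lookup, this guard alone does not help when a package is installed but fails at import time.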

Thanks for the report! This should be fixed in #188, and I've just made a new release that includes the fix.

My environment has jax but not jaxlib (don't ask me why 🤔).

This means that importlib.util.find_spec("jax") succeeds, but importlib.util.find_spec("jax._src.traceback_util") fails with ModuleNotFoundError: jax requires jaxlib to be installed.

Thoughts on this? Should jaxtyping verify that jax not only exists, but is importable? Or is it basically just user error to have jax without jaxlib?
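The distinction raised here, between a package being locatable and being importable, can be reproduced without jax. The sketch below builds a throwaway package, `fakejax_demo` (a hypothetical name), whose `__init__.py` raises `ModuleNotFoundError` at import time, loosely mimicking jax installed without jaxlib. `find_spec` on the top-level name succeeds because it does not run `__init__.py`, while the submodule lookup imports the parent and fails. The hypothetical helper `is_importable` answers the "is it importable?" question by actually importing.

```python
import importlib
import importlib.util
import pathlib
import sys
import tempfile


def is_importable(name: str) -> bool:
    """Return True only if `name` can actually be imported.

    Unlike importlib.util.find_spec, this runs the package's __init__,
    so it catches packages that are present on disk but raise an
    ImportError (or subclass) at import time.
    """
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    return True


# Build a throwaway package that is discoverable on sys.path but whose
# __init__ raises at import time. "fakejax_demo" is a hypothetical name.
tmp = tempfile.mkdtemp()
pkg = pathlib.Path(tmp) / "fakejax_demo"
(pkg / "_src").mkdir(parents=True)
(pkg / "__init__.py").write_text(
    'raise ModuleNotFoundError("fakejax_demo requires fakejaxlib_demo to be installed.")\n'
)
(pkg / "_src" / "__init__.py").write_text("")
(pkg / "_src" / "traceback_util.py").write_text("")
sys.path.insert(0, tmp)
importlib.invalidate_caches()

# The top-level spec is found without running __init__.py ...
top_found = importlib.util.find_spec("fakejax_demo") is not None

# ... but looking up a submodule imports the parent, which raises.
try:
    importlib.util.find_spec("fakejax_demo._src.traceback_util")
    sub_error = None
except ModuleNotFoundError as e:
    sub_error = str(e)

print(top_found)                      # True
print(sub_error)                      # fakejax_demo requires fakejaxlib_demo to be installed.
print(is_importable("fakejax_demo"))  # False
```

Actually importing is the most direct answer, but it executes the package's import-time code, which a lighter `find_spec`-based check is presumably meant to avoid.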

Aha! That's unfortunate.
I think we might as well add a check for jaxlib too. I'd be happy to take a PR for that.

Will do!
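A minimal sketch of the jaxlib check discussed above, assuming that confirming the top-level `jax` and `jaxlib` specs both exist is enough to decide whether the submodule lookup is safe. The function name `submodule_available` and the stand-in package names are hypothetical, and this is not necessarily what was merged into jaxtyping.

```python
import importlib.util


def submodule_available(submodule: str, required: tuple[str, ...]) -> bool:
    """Look up `submodule` only if every top-level package in `required` exists.

    find_spec on a name without a dot does not run the package's __init__,
    so these checks are cheap. The final lookup does import the parent
    package, so it is only attempted once all requirements are present.
    """
    for name in required:
        if importlib.util.find_spec(name) is None:
            return False
    return importlib.util.find_spec(submodule) is not None


# Stand-in names (hypothetical) so the example runs without jax installed.
print(submodule_available("json.decoder", ("json",)))                  # True
print(submodule_available("json.decoder", ("json", "nope_lib_demo")))  # False
```

For this thread the call would be `submodule_available("jax._src.traceback_util", ("jax", "jaxlib"))`. This only checks that jaxlib is present, not that its version is compatible with jax; the final lookup still imports jax, so any other import-time failure would still propagate.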