jax dependency error when jax is not installed
jsternabsci opened this issue:
jaxtyping 0.2.26 introduced a bug when using the @jaxtyped decorator without jax installed.
Error:
/opt/conda/envs/testenv-20743/lib/python3.11/site-packages/jaxtyping/_decorator.py:192: in jaxtyped
if _tb_flag and importlib.util.find_spec("jax._src.traceback_util") is not None:
E ModuleNotFoundError: No module named 'jax'
Breaking change:
v0.2.25...v0.2.26#diff-f792a47fc41c0cf008332f62022e6136a8f6c6e0514c743eedc7df172e519ce5R192
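To see why this line fails, here is a small standard-library sketch; the package name not_installed_pkg is a hypothetical placeholder assumed not to be installed. For a top-level name, importlib.util.find_spec returns None when nothing is found. For a dotted name, it imports the parent package first, so a missing parent raises ModuleNotFoundError instead of returning None, which is what the traceback above shows for jax.

```python
import importlib.util


def probe(name: str):
    """Return the spec found for `name`, or the ModuleNotFoundError that find_spec raised."""
    try:
        return importlib.util.find_spec(name)
    except ModuleNotFoundError as exc:
        return exc


# Top-level name that is not installed (hypothetical): find_spec just returns None.
top = probe("not_installed_pkg")
# Dotted name whose parent is missing: find_spec imports the parent first, which raises.
sub = probe("not_installed_pkg.submodule")
print(top, repr(sub))
```

So a guard of the form `find_spec("jax._src.traceback_util") is not None` is only safe once the top-level package is known to be importable.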
Thanks for the report! This should be fixed in #188, and I've just done a new release to include this.
My environment has jax but not jaxlib (don't ask me why 🤔). This means that importlib.util.find_spec("jax") succeeds, but importlib.util.find_spec("jax._src.traceback_util") fails with ModuleNotFoundError: jax requires jaxlib to be installed.
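The same mechanism explains the jax-without-jaxlib case: looking up a top-level spec only locates the package on disk, while looking up a submodule's spec executes the parent's __init__.py. The sketch below mimics this with a throwaway package; fakejax and fakejaxlib are hypothetical stand-ins, and the raised message is modeled on the one quoted above rather than taken from jax itself.

```python
import importlib
import importlib.util
import pathlib
import sys
import tempfile


def demo_parent_import_failure():
    """Build a throwaway package whose __init__.py fails, then probe it with find_spec."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg = pathlib.Path(tmp) / "fakejax"  # hypothetical stand-in for jax
        pkg.mkdir()
        # Importing the package fails, like jax installed without jaxlib.
        (pkg / "__init__.py").write_text(
            'raise ModuleNotFoundError("fakejax requires fakejaxlib to be installed.")\n'
        )
        (pkg / "traceback_util.py").write_text("")  # the submodule itself is fine
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            # Top-level lookup only locates the package; __init__.py is not run.
            top_found = importlib.util.find_spec("fakejax") is not None
            # Dotted lookup imports the parent package first, which runs __init__.py.
            try:
                importlib.util.find_spec("fakejax.traceback_util")
                sub_error = None
            except ModuleNotFoundError as exc:
                sub_error = str(exc)
            return top_found, sub_error
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("fakejax", None)


print(demo_parent_import_failure())
```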
Thoughts on this? Should jaxtyping verify that jax not only exists but is importable? Or is it basically just user error to have jax without jaxlib?
Aha! That's unfortunate.
I think we might as well add a check for jaxlib too. I'd be happy to take a PR on that?
Will do!