ImportError: cannot import name 'Array' from 'jaxtyping'
oadams opened this issue
Oliver Adams commented
Oddly, I can't import Array.
$ pip freeze | grep jaxtyping
jaxtyping==0.2.23
$ python
Python 3.11.3 (main, Apr 26 2023, 07:55:34) [Clang 14.0.0 (clang-1400.0.29.202)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from jaxtyping import Array
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'Array' from 'jaxtyping' (/Users/oadams/.pyenv/versions/gpt/lib/python3.11/site-packages/jaxtyping/__init__.py)
I can import Float, but I can't import PyTree either.
Patrick Kidger commented
You need to have JAX installed :)
Oliver Adams commented
Thanks!
I was just interested in using it for PyTorch. On a closer look at the docs, I see that what I really wanted was Float[torch.Tensor, '...'].