patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/



ImportError: cannot import name 'Array' from 'jaxtyping'

oadams opened this issue:

Oddly, I can't import Array.

$ pip freeze | grep jaxtyping
jaxtyping==0.2.23
$ python
Python 3.11.3 (main, Apr 26 2023, 07:55:34) [Clang 14.0.0 (clang-1400.0.29.202)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from jaxtyping import Array
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'Array' from 'jaxtyping' (/Users/oadams/.pyenv/versions/gpt/lib/python3.11/site-packages/jaxtyping/__init__.py)

I can import Float, but I can't import PyTree either.

You need to have JAX installed :)
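Per the reply above, names like Array and PyTree are only available from jaxtyping when JAX is installed, while Float imports either way. A quick way to confirm which case you are in is to check for JAX and probe the names directly. This is a minimal diagnostic sketch using only the standard library; the helpers module_available and name_importable are hypothetical names chosen for illustration, not part of jaxtyping.

```python
import importlib
import importlib.util


def module_available(name: str) -> bool:
    """Return True if top-level module `name` can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


def name_importable(module: str, name: str) -> bool:
    """Return True if `module` imports successfully and exposes attribute `name`."""
    try:
        return hasattr(importlib.import_module(module), name)
    except ImportError:
        # Covers both "module not installed" and failures raised while importing it.
        return False


if __name__ == "__main__":
    print("jax installed:", module_available("jax"))
    for attr in ("Float", "Array", "PyTree"):
        print(f"jaxtyping.{attr} importable:", name_importable("jaxtyping", attr))
```

If "jax installed" prints False while Float is importable and Array is not, installing JAX (or, for non-JAX code, avoiding the JAX-specific names) resolves the ImportError.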

Thanks!

I was only interested in using it with PyTorch. On a closer look at the docs, I see that what I actually wanted was Float[torch.Tensor, '...'].