Type annotations and runtime checking for the shape and dtype of arrays, and for PyTrees.
For example:
from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching dimensions
def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]
                  ) -> Float[Array, "dim1 dim3"]:
    ...

def accepts_pytree_of_ints(x: PyTree[int]):
    ...

def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
    ...
pip install jaxtyping
Requires Python 3.8+.
JAX is an optional dependency, required for jaxtyping.{Array, ArrayLike, PyTree}. If JAX is not installed then these types will not be available, but you may still use jaxtyping alongside PyTorch/NumPy/etc.
Also install your favourite runtime type-checking package. The two most popular are typeguard (which exhaustively checks every argument) and beartype (which checks random pieces of arguments).
FAQ (static type checking, flake8, etc.)
Neural networks: Equinox.
Numerical differential equation solvers: Diffrax.
Computer vision models: Eqxvision.
SymPy<->JAX conversion; train symbolic expressions via gradient descent: sympy2jax.
This is not an official Google product.