[BUG] Order sensitivity for implicit casting of jax Arrays with infix operators
amifalk opened this issue
Adding shapeless (scalar) JAX arrays to named arrays with infix notation works when the JAX array is on the right, but not when it is on the left. This does not happen when you `nmap` the `jnp.add` function directly, so a potential solution might be to replace the `NamedArrayBase` dunder methods with their `jax.numpy` versions, e.g.

```python
__add__ = _nmap_with_doc(operator.add, "jax.Array.__add__")
# ->
__add__ = _nmap_with_doc(jax.numpy.add, "jax.Array.__add__")
```
Reproducible example:

```python
# Broken
import jax.numpy as jnp
from penzai import pz

def test_right_const(arr):
    return arr + jnp.array(3.)

def test_left_const(arr):
    return jnp.array(3.) + arr

arr = pz.nx.arange('arr', 3)
test_right_const(arr)  # works
test_left_const(arr)  # fails
```

```
ValueError: Only NamedArray(View)s with no named axes can be converted to JAX arrays. Use `unwrap` or `untag` to assign positions to named axes first, or use `penzai.named_axes.nmap` with a JAX function instead.
```
```python
# Working
def no_infix_right_const(arr):
    return pz.nx.nmap(jnp.add)(arr, jnp.array(3.))

def no_infix_left_const(arr):
    return pz.nx.nmap(jnp.add)(jnp.array(3.), arr)

# these both work
no_infix_left_const(arr)
no_infix_right_const(arr)
```
Thanks for the report and the reproducing example!
Unfortunately, I don't think this has to do with `jnp.add` vs `operator.add`; instead, it has to do with Python's infix resolution behavior. The explicit version works even if you use `operator.add`:
```python
import operator

def no_infix_right_const(arr):
    return pz.nx.nmap(operator.add)(arr, jnp.array(3.))

def no_infix_left_const(arr):
    return pz.nx.nmap(operator.add)(jnp.array(3.), arr)

# works
no_infix_left_const(arr)
no_infix_right_const(arr)
```
I believe what's happening here is:

- Python translates `a + b` to `a.__add__(b)`, and only tries `b.__radd__(a)` if `a.__add__(b)` returns `NotImplemented` (described here).
- If the leftmost object is a JAX array, JAX calls a wrapped version of `jnp.add`, which calls `__jax_array__` if it exists.
- Penzai's `NamedArrayBase` defines `__jax_array__` to try to unwrap the array, raising an error if there are still named axes.
I think fixing this will require removing the automatic-unwrap support for NamedArrays, so that JAX doesn't try to call `__jax_array__` and instead returns `NotImplemented` (in which case Python should fall back to the `__radd__` method).
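The dispatch sequence described above can be reproduced in plain Python, without JAX or Penzai. This is a toy model only: `LeftOperand` and `ToyNamedArray` are hypothetical stand-ins, not the real `jax.Array` or `NamedArray` classes, and real JAX dispatch is more involved. It shows that an exception raised by an eagerly called `__jax_array__` hook escapes `a + b` with no fallback, whereas returning `NotImplemented` lets Python try the right operand's `__radd__`.

```python
class LeftOperand:
    """Hypothetical stand-in for a jax.Array on the left of `+`.

    With converts_operands=True it eagerly calls the right operand's
    __jax_array__ hook (the behavior described above). With
    converts_operands=False it returns NotImplemented for unknown types.
    """

    def __init__(self, value, converts_operands):
        self.value = value
        self.converts_operands = converts_operands

    def __add__(self, other):
        if isinstance(other, (int, float)):
            return self.value + other
        if self.converts_operands and hasattr(other, "__jax_array__"):
            # Any exception raised by the hook propagates out of `a + b`;
            # Python never gets a chance to try other.__radd__.
            return self.value + other.__jax_array__()
        # Unrecognized operand: tell Python to try the reflected method.
        return NotImplemented


def _plain_value(x):
    # Unwrap the toy left operand; leave plain numbers untouched.
    return x.value if isinstance(x, LeftOperand) else x


class ToyNamedArray:
    """Hypothetical stand-in for a NamedArray that may still have named axes."""

    def __init__(self, value, names):
        self.value = value
        self.names = names

    def __jax_array__(self):
        # Mimics the unwrap hook: refuses to convert while named axes remain.
        if self.names:
            raise ValueError("still has named axes; unwrap or untag first")
        return self.value

    def __add__(self, other):
        return ToyNamedArray(self.value + _plain_value(other), self.names)

    def __radd__(self, other):
        return ToyNamedArray(_plain_value(other) + self.value, self.names)


arr = ToyNamedArray(1.0, names=("arr",))

# Named array on the left: ToyNamedArray.__add__ runs directly.
print((arr + LeftOperand(3.0, converts_operands=True)).value)

# Left operand calls the hook eagerly: the ValueError escapes, no __radd__ fallback.
try:
    LeftOperand(3.0, converts_operands=True) + arr
except ValueError as err:
    print("ValueError:", err)

# Left operand returns NotImplemented: Python falls back to ToyNamedArray.__radd__.
print((LeftOperand(3.0, converts_operands=False) + arr).value)
```

The last case corresponds to the proposed fix: once the left operand no longer converts the named array eagerly, the reflected method gets a chance to handle the operation.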