google-deepmind / penzai

A JAX research toolkit for building, editing, and visualizing neural networks.

Home page: https://penzai.readthedocs.io/

[BUG] Order sensitivity for implicit casting of JAX arrays with infix operators

amifalk opened this issue.

Adding shapeless (zero-dimensional) JAX arrays to named arrays with infix notation works when the JAX array is on the right, but not when it is on the left. This does not happen when you nmap the jnp.add function directly, so a potential solution might be to replace the NamedArrayBase dunder methods with their jax.numpy equivalents, e.g.

__add__ = _nmap_with_doc(operator.add, "jax.Array.__add__")  ->
__add__ = _nmap_with_doc(jax.numpy.add, "jax.Array.__add__")

Reproducible example:

# Broken
import jax.numpy as jnp
from penzai import pz

def test_right_const(arr):
    return arr + jnp.array(3.) 

def test_left_const(arr):
    return jnp.array(3.) + arr     

arr = pz.nx.arange('arr', 3)

test_right_const(arr)  # works
test_left_const(arr)   # fails with:
# ValueError: Only NamedArray(View)s with no named axes can be converted to JAX arrays. Use `unwrap` or `untag` to assign positions to named axes first, or use `penzai.named_axes.nmap` with a JAX function instead.
# Working
def no_infix_right_const(arr):
    return pz.nx.nmap(jnp.add)(arr, jnp.array(3.))

def no_infix_left_const(arr):
    return pz.nx.nmap(jnp.add)(jnp.array(3.), arr)

# these both work
no_infix_left_const(arr)
no_infix_right_const(arr)

Thanks for the report and the reproducing example!

Unfortunately, I don't think this is caused by jnp.add vs. operator.add; instead, it comes from how Python resolves infix operators. The explicit version works even if you use operator.add:

import operator

def no_infix_right_const(arr):
    return pz.nx.nmap(operator.add)(arr, jnp.array(3.))

def no_infix_left_const(arr):
    return pz.nx.nmap(operator.add)(jnp.array(3.), arr)

# works
no_infix_left_const(arr)
no_infix_right_const(arr)

I believe what's happening here is:

  • Python translates a + b to a.__add__(b) and only tries b.__radd__(a) if a.__add__(b) returns NotImplemented (described here)
  • If the left operand is a JAX array, its __add__ calls a wrapped version of jnp.add, which calls __jax_array__ on the other operand if that method exists
  • Penzai's NamedArrayBase defines __jax_array__ to try to unwrap the named array, raising an error if any named axes remain.
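The dispatch order described above can be reproduced without JAX or Penzai. The sketch below is a simplified model, not JAX's or Penzai's actual code: `FakeJaxArray` and `FakeNamedArray` are hypothetical stand-ins. The first plays a JAX array whose `__add__` eagerly converts its operand through a `__jax_array__`-style hook; the second plays a NamedArray whose hook raises while named axes remain. Because the exception escapes the left operand's `__add__`, Python never gets the chance to call `__radd__`.

```python
class FakeNamedArray:
    """Hypothetical stand-in for a NamedArray that still has named axes."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        # Modeled on the behavior described above: refuse to convert
        # while named axes remain.
        raise ValueError("Only NamedArray(View)s with no named axes can be converted")

    def __add__(self, other):
        # Named array on the left: it handles the plain array itself.
        return FakeNamedArray(self.value + other.value)

    def __radd__(self, other):
        # Would also work with the named array on the right,
        # but only if Python ever reaches it.
        return FakeNamedArray(other.value + self.value)


class FakeJaxArray:
    """Hypothetical stand-in for a jax.Array with an eager __add__."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        # Like the wrapped jnp.add, try to convert the other operand first.
        # If that raises, the exception propagates out of `+` and Python
        # never tries other.__radd__.
        if hasattr(other, "__jax_array__"):
            other = other.__jax_array__()
        return FakeJaxArray(self.value + other.value)


named = FakeNamedArray(1.0)
const = FakeJaxArray(3.0)

print((named + const).value)  # 4.0, handled by FakeNamedArray.__add__
try:
    const + named  # FakeJaxArray.__add__ calls named.__jax_array__, which raises
except ValueError as e:
    print("ValueError:", e)
```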

I think fixing this will require removing the automatic-unwrap support for NamedArrays, so that JAX doesn't try to call __jax_array__ and instead returns NotImplemented (in which case, Python should fall back to the __radd__ method).
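Under the same simplified model, a minimal sketch of the proposed fix: once the named-array stand-in no longer defines `__jax_array__`, the array stand-in's `__add__` treats it as an unrecognized type and returns `NotImplemented`, so Python falls back to the named array's `__radd__`. The class names are hypothetical, and the assumption that the real JAX `__add__` defers in this way mirrors the description in the comment above rather than JAX's source.

```python
class NamedArrayNoHook:
    """Hypothetical NamedArray stand-in with automatic unwrapping removed."""

    def __init__(self, value):
        self.value = value

    def __radd__(self, other):
        # Reached only because the left operand's __add__ returned NotImplemented.
        return NamedArrayNoHook(other.value + self.value)


class DeferringArray:
    """Hypothetical jax.Array stand-in that defers on unrecognized operands."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        if isinstance(other, DeferringArray):
            return DeferringArray(self.value + other.value)
        # Unrecognized operand with no conversion hook: return NotImplemented
        # so Python tries other.__radd__(self) next.
        return NotImplemented


const = DeferringArray(3.0)
named = NamedArrayNoHook(1.0)
result = const + named  # DeferringArray.__add__ -> NotImplemented -> NamedArrayNoHook.__radd__
print(type(result).__name__, result.value)  # NamedArrayNoHook 4.0
```

The design choice here is that the left operand declines politely (`NotImplemented`) instead of attempting a conversion that can fail, which is what lets the right operand's reflected method take over.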