google-deepmind / penzai

A JAX research toolkit for building, editing, and visualizing neural networks.

Home Page: https://penzai.readthedocs.io/


Keep axes in nmap

srush opened this issue

Is there any way to specify axes that you want to push into the nmap? Currently it feels like I need to go fully positional inside the function.

I guess it would be something like

x = x.tag("foo", "bar")
vmap(f)(x.untag("foo"))

Where f can see bar?

Can you expand on what you are trying to do? This is something I considered adding but I couldn't figure out what use cases this would support.

Most operations on named arrays already vectorize automatically over existing axis names, so it seems like you could just pass x to f and let f operate on whichever axes it needs? If you want f to be generic over axis names, you could pass "bar" as an explicit argument to f to tell it which axes it should operate over, and not use vmap here at all. This is how most of Penzai's existing layers work.

I'll also note that if you absolutely need to do something like this, you can actually already do

vmap(f)(
    x.untag("foo").with_positional_prefix()
).tag("foo")

and it should just work. with_positional_prefix moves the untagged axes to the front so that you can map over them with ordinary JAX pytree manipulation.

This came up in my code because I had an array with "row", "col", and "val" names. I had a function that used the "val" axis, but I wanted it to work by vmapping over either "row" or "col". In JAX that was clean, but in Penzai I couldn't figure out how to do it nicely. One option was to make the function purely positional. The other was to pass in the vmap name as a static arg. Neither seemed as clean as I wanted.

btw, I'm having a blast with the library. Just worked through the tensor-puzzles in pure Penzai.

https://srush.github.io/Tensor-Puzzles-Penzai/Tensor_Puzzlers_Penzai.html

Hm, and why couldn't you just call the function on the full array without vmapping, and just untag "val" inside the function? Were you trying to make the function generic over the name of the axis instead of always having it be called "val"? Or is it using the axis names in some other way?

(And I'm glad you're enjoying using Penzai!)

Yeah, I guess what you are saying would work fine. Maybe my concern is mostly aesthetic: I would like to be able to hide a named axis, say "batch", from the inner function if I am effectively vmapping over it. Maybe you are saying that this is not really a problem, since it is much harder to use these extra axes in Penzai than in NumPy notation.

Yeah, so far I haven't run into a situation where having an extra axis name causes a problem, since you usually need to know the name of an axis in order to operate on it. So it's usually possible to write code that is automatically polymorphic over the extra batch axes without requiring a function transformation.

I guess one exception is if you're worried about axis name conflicts. But axis names can be arbitrary hashable objects, so you can effectively hide an axis by tagging it with a fresh Python object() (which is equal only to itself) and then untagging it later.

(If there are concrete reasons why this doesn't work in all cases, though, I'd be open to revisiting this!)