nn.Embed cannot be hashed -> doesn't work with jax.jit static_argnums
jatentaki opened this issue
Problem you have encountered:

There is some issue with hashing of `nn.Embed`, which means it cannot be used as an input to functions annotated with `jax.jit`. An example situation is when one wishes to have a `train_step` function that is generic over the actual network executed: when you try to pass the model as a static argument, it works with modules like `nn.Dense` but not with `nn.Embed`.
What you expected to happen:

`jax.jit` to work with static arguments including `nn.Embed`.
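For context, the `static_argnums` pattern the issue refers to looks roughly like this. This is a minimal sketch with a plain hashable callable standing in for the model; `apply_model` is an illustrative name, not from the original report:

```python
from functools import partial

import jax

# Any static argument must be hashable, because jax.jit uses it as part
# of its compilation-cache key. A plain Python function is hashable, so
# it works here; an nn.Embed instance (at the time of this issue) did not.
@partial(jax.jit, static_argnums=0)
def apply_model(model, x):
    return model(x)
```

Each distinct `model` object triggers a fresh trace and compilation, which is exactly why JAX insists that static arguments be hashable.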
Steps to reproduce:

This may contain some superfluous code (`optax` and such), but I hope it conveys the idea clearly enough.
In Flax, we would not usually pass function references around as static argnums; instead we pass them in as part of a pytree, with an annotation that they should not be transformed. In your case, the simplest solution would be to extend `TrainState` and add an `embed_apply_fn` attribute with that annotation:
```python
from typing import Callable

from flax import struct
from flax.training import train_state


class TrainState(train_state.TrainState):
  embed_apply_fn: Callable = struct.field(pytree_node=False)
```
Then you can initialize the state like this:

```python
state = TrainState.create(
    apply_fn=model.apply,
    embed_apply_fn=embed.apply,
    params=params,
    tx=optax.adam(1e-3),
)
```
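To see what `pytree_node=False` does under the hood, here is a minimal sketch using only `jax.tree_util` (the `MiniState` container and helper names are made up for illustration): the function field travels as static aux data rather than as a leaf, so transformations like `jit` never trace it.

```python
from dataclasses import dataclass
from typing import Any, Callable

import jax


@dataclass
class MiniState:
    params: Any          # pytree leaf: traced/transformed by jit, grad, ...
    apply_fn: Callable   # static metadata: hashed, never traced


def _flatten(s):
    # Children are transformed; aux_data is static (like pytree_node=False).
    return (s.params,), s.apply_fn


def _unflatten(aux_data, children):
    return MiniState(children[0], aux_data)


jax.tree_util.register_pytree_node(MiniState, _flatten, _unflatten)

state = MiniState(params=1.5, apply_fn=abs)
# Only `params` shows up as a leaf; `apply_fn` rides along untouched.
leaves = jax.tree_util.tree_leaves(state)
```

`flax.struct.dataclass` performs essentially this registration for you, putting every `pytree_node=False` field into the static aux data.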
This will reduce the parameter count for your `train_step()`, which now simply becomes:

```python
@jax.jit
def train_step(state, i):
  def loss_fn(params):
    y = state.embed_apply_fn(params['embed'], i)
    x = state.apply_fn(params['model'], y)
    # ...
```
As for a minimal repro, we could say:

```python
import flax

hash(flax.linen.Dense(10))   # Works.
hash(flax.linen.Embed(2, 3)) # Fails.
```
The difference is due to a field that is not initialized, which makes the dataclass-generated `__hash__` function fail:

Line 402 in e30b7f5
As shown by:

```python
embed = flax.linen.Embed(2, 3)
object.__setattr__(embed, 'embedding', None)
hash(embed)  # Works.
```
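The same failure can be reproduced with nothing but the standard library: a frozen dataclass that declares a `field(init=False)` attribute without ever assigning it generates a `__hash__` that reads the missing attribute. A minimal sketch (class and field names are illustrative):

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Toy:
    features: int
    # Declared but never assigned: __init__ skips init=False fields
    # without defaults, so the instance attribute simply does not exist.
    embedding: object = field(init=False)


toy = Toy(10)
try:
    hash(toy)            # generated __hash__ reads self.embedding
    raised = False
except AttributeError:
    raised = True         # AttributeError: no attribute 'embedding'

# The object.__setattr__ workaround shown above makes hashing succeed,
# because the attribute now exists when __hash__ reads it.
object.__setattr__(toy, 'embedding', None)
works = isinstance(hash(toy), int)
```

This is why the `nn.Embed` instance only becomes hashable once something (normally `init`/`apply` binding the scope) actually sets the field.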
Tagging @jheek here, who introduced the above `embedding: Array = field(init=False)` in #643.
@andsteing thanks, that looks like a solution. May I ask for the rationale behind adopting this pattern, though? I'm thinking of pytrees as a way to store the state of a computation, and while it may be convenient to have non-transformed fields for some edge cases, the approach above feels to me like a hack. After all, if we put both the state and the implementation in pytrees, what is the purpose of `nn.Module`s? Should I think of them as just a factory function, used to generate the pytree, which then contains the entire API of my model?
Secondly, how does the non-transformed property play with `jax.jit`? After all, these `apply_xyz` functions are what we are looking to transform with jit. The approach you're proposing requires JAX to figure out that the code is static even though it's passed through a field we don't annotate as such. Are functions special-cased as always static? After all, they may have closed over arbitrary mutable state.
I'm sorry if I sound critical, I'm just trying to align my intuition about how to use flax with that of its creators. Thank you very much.
Yes, it's a convenient way of passing a mix of parameters and functions through transformations like `jit()` and `pmap()`. Note that even though you don't specify `apply_fn`, you're already making use of this pattern when calling `state.apply_gradients()`, which uses `state.tx` internally:

flax/flax/training/train_state.py, Line 55 in e30b7f5
There is some discussion about this pattern in FLIP 1009, where you can also see alternatives.
There is nothing wrong with passing in all the functions as static argnums (or referring to them through an outer scope), but it can become quite verbose, and that's why we prefer this dataclass transform/no-transform pattern in our projects (e.g. our examples).
As for the purpose of `nn.Module`: after having things set up and initialized, most modules are really only used through `.apply_fn()`. It's not a factory pattern in the classic sense, but for many modules (like `Dense` and `Embed`) you could see the whole `nn.Module` machinery (which allows nesting modules, sets up and tracks scope, updates RNG key chains, stores parameters, etc.) as "producing" a single function in the end (or two in the case of `Embed`).
As for your second question: your function can indeed close over arbitrary mutable state, and that's a bad idea regardless of whether you pass it via `static_argnums` or via a pytree dataclass field that has `pytree_node=False`. JAX expects you to transform pure functions, and that includes all functions you call from inside those transformed functions, regardless of how they're passed in. If you're not transforming pure functions, you're breaking the contract and there are no guarantees as to what your transformed functions will actually do (in some cases you might get an error transforming such a function, but in many cases JAX will silently comply).
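To illustrate the "silently comply" case: a jitted function that closes over a mutable Python value bakes that value in at trace time, so later mutations are silently ignored. A minimal sketch:

```python
import jax

scale = 2.0


@jax.jit
def times_scale(x):
    # `scale` is read from the enclosing scope during tracing and
    # becomes a compile-time constant in the compiled function.
    return x * scale


first = float(times_scale(1.0))   # traces and compiles with scale == 2.0
scale = 3.0                       # mutation after tracing...
second = float(times_scale(1.0))  # ...hits the cache: still uses scale == 2.0
```

Both calls return 2.0, because the second call reuses the cached compilation. This is why impure closures are unsafe no matter how the function reaches the transformed code.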
Thanks once again. I suppose I'll leave this issue open in case @jhee decides there's something to be changed about nn.Embed, but on my side the issue is resolved.
@jheek - see above request for comment from jatentaki (your handle was mis-spelled)
Thanks, I created a PR to fix the issue