google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


nn.Embed cannot be hashed -> doesn't work with jax.jit static_argnums

jatentaki opened this issue

Problem you have encountered:

There is some issue with hashing of nn.Embed, which means it cannot be passed via static_argnums to functions transformed with jax.jit. An example situation is when one wishes to have a train_step function that is generic over the actual network executed: when you try to pass the model as a static argument, it works with modules like nn.Dense but not with nn.Embed.

What you expected to happen:

jax.jit to work with static arguments including nn.Embed.

Steps to reproduce:

This may contain some superfluous code (optax and stuff) but I hope it conveys the idea clearly enough.

In Flax, we would not usually pass around function references as static argnums, but instead pass them in as part of a PyTree with the annotation that they should not be transformed.

In your case, the simplest solution would be to extend TrainState and add an embed_apply_fn attribute with that annotation:

from typing import Callable

from flax import struct
from flax.training import train_state

class TrainState(train_state.TrainState):
  embed_apply_fn: Callable = struct.field(pytree_node=False)

Then you can initialize the state like this:

    state = TrainState.create(
        apply_fn=model.apply,
        embed_apply_fn=embed.apply,
        params=params,
        tx=optax.adam(1e-3),
    )

This reduces the parameter count of your train_step(), which now simply becomes:

@jax.jit
def train_step(state, i):
    def loss_fn(params):
        y = state.embed_apply_fn(params['embed'], i)
        x = state.apply_fn(params['model'], y)
        # ...

As for a minimal repro, we could use:

import flax
hash(flax.linen.Dense(10))  # Works
hash(flax.linen.Embed(2, 3))  # Fails

The difference is due to a field that is not initialized, which makes the dataclass-generated __hash__ function fail:

embedding: Array = field(init=False)

As shown by

embed = flax.linen.Embed(2, 3)
object.__setattr__(embed, 'embedding', None)
hash(embed)  # Works
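The mechanism can be illustrated in plain Python, independent of Flax (a hypothetical dataclass, not the actual nn.Embed definition):

```python
# A frozen dataclass (eq=True by default) gets a generated __hash__
# that reads every compared field, including one declared with
# init=False. If that field is never assigned, hashing raises
# AttributeError.
from dataclasses import dataclass, field

@dataclass(frozen=True)
class FakeEmbed:
    features: int
    embedding: object = field(init=False)

m = FakeEmbed(3)
try:
    hash(m)
except AttributeError as e:
    print("hash failed:", e)

# Assigning the field first (as in the workaround above) makes the
# instance hashable again.
object.__setattr__(m, 'embedding', None)
hash(m)  # now works
```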

Tagging @jheek here, who introduced the above embedding: Array = field(init=False) in #643.

@andsteing thanks, that looks like a solution. May I ask for the rationale behind adopting this pattern, though? I think of pytrees as a way to store the state of a computation, and while it may be convenient to have non-transformed fields for some edge cases, the approach above feels to me like a hack. After all, if we put both the state and the implementation in pytrees, what is the purpose of nn.Modules? Should I think of them as just factory functions, used to generate the pytree which then contains the entire API of my model?
Secondly, how does the non-transformed property play with jax.jit? After all, these apply_xyz functions are what we are looking to transform with jit. The approach you're proposing requires jax to figure out that the code is static even though it's passed through a field we don't annotate as such. Are functions special-cased as always static? After all, they may have closed over arbitrary mutable state.

I'm sorry if I sound critical, I'm just trying to align my intuition about how to use flax with that of its creators. Thank you very much.

Yes, it's a convenient way of passing a mix of parameters and functions through transformations like jit() and pmap(). Note that even though you don't specify apply_fn, you're already making use of this pattern when calling state.apply_gradients(), which uses state.tx internally:

tx: optax.GradientTransformation = struct.field(pytree_node=False)

There is some discussion about this pattern in FLIP 1009, where you can also see alternatives.

There is nothing wrong with passing in all the functions as static argnums (or referring to them through an outer scope), but it can become quite verbose, and that's why we prefer this dataclass transform/no-transform pattern in our projects (e.g. our examples).

As for the purpose of nn.Module: after things are set up and initialized, most modules are really only used through .apply(). It's not a factory pattern in the classic sense, but for many modules (like Dense and Embed) you could see the whole nn.Module machinery (which allows nesting modules, sets up and tracks scope, updates RNG key chains, stores parameters, etc.) as "producing" a single function in the end (or two in the case of Embed).

As for your second question: your function can indeed close over arbitrary mutable state, and that's a bad idea regardless of whether you pass it via static_argnums or via a pytree dataclass field that has pytree_node=False. JAX expects you to transform pure functions, and that includes all functions you call from inside those transformed functions, regardless of how they're passed in. If you're not transforming pure functions, you're breaking the contract, and there are no guarantees as to what your transformed functions will actually do (in some cases you might get an error transforming such a function, but in many cases JAX will silently comply).
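A small sketch of how this "silent compliance" can bite (pure JAX, illustrative names): jit caches the traced computation, so a closed-over Python value is baked in at trace time and later mutations are ignored.

```python
import jax
import jax.numpy as jnp

scale = [2.0]  # mutable state captured by the closure

@jax.jit
def f(x):
    # scale[0] is read once at trace time and baked into the cached
    # computation as a constant.
    return x * scale[0]

first = f(jnp.ones(1))   # traces with scale[0] == 2.0
scale[0] = 10.0
second = f(jnp.ones(1))  # cached trace still multiplies by 2.0
```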

Thanks once again. I suppose I'll leave this issue open in case @jhee decides there's something to be changed about nn.Embed, but on my side the issue is resolved.

@jheek - see above request for comment from jatentaki (your handle was mis-spelled)


Thanks, I created a PR to fix the issue.