google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

How do I lift the parameters out of a module?

NeilGirdhar opened this issue

I've been considering switching from Haiku to Flax, but I ran into a couple of things that seem to be impossible in Flax. One thing is the ability to "lift" parameters within a transformation, as described in google-deepmind/dm-haiku#98.

In my code, I do something like:

from functools import partial
from typing import Callable

import haiku as hk
from jax import custom_vjp

# PredictionMessage, RealArray and layer are my own types/functions.

class SomeModule(hk.Module):
    def __call__(self, message: PredictionMessage) -> PredictionMessage:
        layer_f = hk.transform(layer)
        init_rng = hk.next_rng_key() if hk.running_init() else None
        # Lift the inner transform's parameters into this module's parameters.
        weights = hk.experimental.lift(layer_f.init, name='f_lift')(init_rng,
                                                                    message.observation)
        message = _apply_layer(layer_f.apply, weights, message)
        return message


@partial(custom_vjp, nondiff_argnums=(0,))
def _apply_layer(f: Callable[[hk.Params, None, RealArray], RealArray],
                 weights: hk.Params,
                 message: PredictionMessage) -> PredictionMessage:
    ...

(Incidentally, how do I port over the init_rng definition?)

Hey Neil! I don't fully understand your code, but if you need a vjp function compatible with Flax modules, you would use flax.linen.vjp.
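
For reference, a minimal sketch of that pattern, roughly following the flax.linen.vjp API docs (LearnScale and Foo are just illustrative names, not taken from your code):

import jax.numpy as jnp
import flax.linen as nn

class LearnScale(nn.Module):
  @nn.compact
  def __call__(self, x):
    # A single scalar parameter, initialized to zero.
    p = self.param('scale', nn.initializers.zeros, ())
    return p * x

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    # nn.vjp takes a function of (module, *primals) plus the module instance,
    # and returns the primal output together with a backward function.
    y, bwd = nn.vjp(lambda mdl, x: mdl(x), LearnScale(), x)
    # bwd returns cotangents for the params and for x.
    params_grad, x_grad = bwd(jnp.ones(y.shape))
    return y, params_grad, x_grad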

@cgarciae were you able to make sense of the linked issue (google-deepmind/dm-haiku#98)? (flax.linen.vjp doesn't help unfortunately.)

(Incidentally, how do I port over the init_rng definition?)

I don't fully understand your code, but if you want to define a param in your Module, you don't have to provide any RNGs: you just provide an init function, and the RNG key gets split from the top-level RNG you provide when you initialize the Module. If this is still unclear, please file the question in a separate GitHub Discussion, to avoid confusion!
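
For example, a minimal sketch (module and parameter names are just for illustration): the init function you pass to self.param only receives an RNG key that Flax splits off the top-level RNG given to .init:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MyDense(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    # Flax calls the init function with an RNG key derived from the top-level
    # 'params' RNG; there is no explicit init_rng handling in the module.
    kernel = self.param('kernel', nn.initializers.lecun_normal(),
                        (x.shape[-1], self.features))
    return x @ kernel

# The only RNG you provide is the one passed to .init at the top level.
variables = MyDense(features=4).init(jax.random.PRNGKey(0), jnp.ones((1, 3)))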

Duplicate of #1713


@NeilGirdhar there's a pattern in Flax that is very similar to haiku.lift that you could try in the meantime: https://flax.readthedocs.io/en/latest/design_notes/lift.html#functionalization

This doc explains how the lifted transformations work in Flax and how you can do something like this manually (in simple cases).

For example (snippet from the doc):

import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn

# MLP is a small Flax Module defined earlier in the linked design note.
class ManualVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    mlp = MLP(parent=None)  # unbound module: init/apply act as pure functions
    init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params']
    apply_fn = jax.vmap(mlp.apply, in_axes=0)
    mlp_params = self.param('mlp', init_fn, xs)
    return apply_fn({'params': mlp_params}, xs)

xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.PRNGKey(0), xs)
print(jax.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        hidden: {
            bias: (3, 4),
            kernel: (3, 4, 4),
        },
        out: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""