How do I lift the parameters out of a module?
NeilGirdhar opened this issue · comments
I've been considering switching from Haiku to Flax, but I ran into a couple of things that seem to be impossible in Flax. One thing is the ability to "lift" parameters within a transformation, as described in google-deepmind/dm-haiku#98.
In my code, I do something like:
class SomeModule:
    def __call__(self, message: PredictionMessage) -> PredictionMessage:
        layer_f = hk.transform(layer)
        init_rng = hk.next_rng_key() if hk.running_init() else None
        weights = hk.experimental.lift(layer_f.init, name='f_lift')(
            init_rng, message.observation)
        message = _apply_layer(layer_f.apply, weights, message)


@partial(custom_vjp[PredictionMessage], nondiff_argnums=(0,))
def _apply_layer(f: Callable[[hk.Params, None, RealArray], RealArray],
                 weights: hk.Params,
                 message: PredictionMessage) -> PredictionMessage:
    ...
(Incidentally, how do I port over the init_rng definition?)
Hey Neil! I don't fully understand your code, but if you need a vjp function compatible with Flax modules, you can use flax.linen.vjp.
@cgarciae were you able to make sense of the linked issue (google-deepmind/dm-haiku#98)? (flax.linen.vjp doesn't help, unfortunately.)
(Incidentally, how do I port over the init_rng definition?)
I don't fully understand your code, but if you want to define a param in your Module, you don't have to provide any RNGs: you just provide an init function, and the RNG gets split from the top-level RNG that you pass when you initialize the Module. If this is still unclear, please file the question in a separate GitHub Discussion, to avoid confusion!
Duplicate of #1713
@NeilGirdhar there's a pattern in Flax that is very similar to haiku.lift that you could try in the meantime: https://flax.readthedocs.io/en/latest/design_notes/lift.html#functionalization
This doc explains how the lifted transformations work in Flax and how you can do something like this manually (in simple cases).
For example (snippet from the doc):
class ManualVmapMLP(nn.Module):
    @nn.compact
    def __call__(self, xs):
        mlp = MLP(parent=None)
        init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(
            random.split(rng, xs.shape[0]), xs)['params']
        apply_fn = jax.vmap(mlp.apply, in_axes=0)
        mlp_params = self.param('mlp', init_fn, xs)
        return apply_fn({'params': mlp_params}, xs)

xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.PRNGKey(0), xs)
print(jax.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        hidden: {
            bias: (3, 4),
            kernel: (3, 4, 4),
        },
        out: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""