How to pass dtype to `self.param(...)` in `flax.linen.Module`
patrickvonplaten opened this issue · comments
Problem you have encountered:
For a specific layernorm implementation, I'm trying to pass `dtype` to `self.param(...)`, but it seems that this is not possible. I would like the user to be able to switch between `bfloat16`, `float16`, and `float32` for the weight.

How can one pass a `dtype` parameter to `self.param(...)`?
What you expected to happen:
I would like to implement the following layer:

```python
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from typing import Callable

class UserSpecificLayerNorm(nn.Module):
    hidden_size: int
    weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # It's not possible to pass dtype as a keyword here
        self.weight = self.param(
            "weight", self.weight_init, (self.hidden_size,), dtype=self.dtype
        )
```

but it doesn't seem to be possible.
Not sure if it is still the case, but in the past the `jax.nn.initializers` couldn't produce half types, so you had to init a float32 weight and cast it manually afterwards.
In Flax each layer allows you to store intermediates in a specified dtype but params are in float32 by default. This is because parameters in half precision tend to be really difficult to optimize.
In your case, for example, storing the LayerNorm activations in half precision should get you up to a 2x speedup, because layernorm is memory bound. But the params are at least two orders of magnitude smaller than the activations, so casting them to half as well shouldn't give you even a 1% speedup while greatly increasing instability.
It is possible to take the params from an initialized network and cast them. This has some nice benefits like the ability to cast only in some specific cases. For example casting before evaluation is always a cheap win but you could also cast part of the network for a fine-tuning task.