google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

How to pass dtype to `self.param(...)` in `flax.linen.Module`

patrickvonplaten opened this issue

Problem you have encountered:

For a specific LayerNorm implementation, I'm trying to pass a dtype to self.param(...), but this does not seem to be possible.
I would like the user to be able to switch between bfloat16, float16, and float32 for the weight.

How can one pass a dtype parameter to self.param(...)?

What you expected to happen:

I would like to implement the following layer:

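from typing import Callable
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn

class UserSpecificLayerNorm(nn.Module):
    hidden_size: int
    weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Desired call: forward the user-selected dtype to the initializer.
        self.weight = self.param("weight", self.weight_init, (self.hidden_size,), dtype=self.dtype)  # it's not possible to pass dtype here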

However, this does not seem to be possible.

commented

I'm not sure whether this is still the case, but in the past the jax.nn.initializers couldn't produce half-precision types.
So you had to initialize a float32 weight and cast it manually afterwards.

In Flax, each layer lets you store intermediates in a specified dtype, but params are float32 by default. This is because parameters in half precision tend to be really difficult to optimize.
In your case, for example, storing the LayerNorm activations in half precision should give you up to a 2x speedup, because LayerNorm is memory-bound. The params, however, are at least two orders of magnitude smaller than the activations, so casting them to half precision as well should not give you even a 1% speedup, while greatly increasing instability.

It is possible to take the params from an initialized network and cast them. This has some nice benefits, such as the ability to cast only in specific cases. For example, casting before evaluation is always a cheap win, but you could also cast only part of the network for a fine-tuning task.
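A minimal sketch of casting already-initialized params, assuming they form an ordinary pytree of arrays. The helper name `cast_floating` and the toy `params` dictionary are hypothetical; in practice the tree would come from `model.init(...)`.

```python
import jax
import jax.numpy as jnp


def cast_floating(params, dtype):
    """Cast every floating-point leaf of a pytree to `dtype`; leave other leaves as-is."""
    def cast(leaf):
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf
    return jax.tree_util.tree_map(cast, params)


# Toy parameter tree standing in for the output of an initialized network.
params = {
    "dense": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))},
    "layer_norm": {"weight": jnp.ones((2,))},
    "step": jnp.array(0, dtype=jnp.int32),
}

# Cast everything, e.g. before evaluation.
eval_params = cast_floating(params, jnp.bfloat16)

# Cast only one sub-tree, e.g. for a fine-tuning setup.
partial_params = dict(params, dense=cast_floating(params["dense"], jnp.bfloat16))
```

Because the cast happens outside the module, the training copy of the params can stay in float32 and the half-precision copy can be created only where it is needed.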