google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


[Feature Request] scaling factor for initializers

PhilipVinc opened this issue · comments

Machine-Learning practitioners are aware that initialisation functions must be properly rescaled depending on the variance of the output of the nonlinearity.

However, scientists like mathematicians/biologists/physicists and many others who start looking into Differentiable Programming and/or Scientific Machine Learning (or, as I call it, "let's throw neural networks at our problems and see how they perform") are relatively 'illiterate' about this kind of thing.
If Flax is supposed to be targeted at these kinds of users, maybe you could do something to make them more prominently aware of this issue?
I've seen many (NetKet) users mindlessly changing activation functions with no idea that they should also have rescaled the initializers accordingly.
Maybe mentioning this in a tutorial/notebook might already be useful.
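For concreteness, here is a minimal sketch of the manual correction using the existing API only: the `scale` argument of `variance_scaling` is the squared gain of the nonlinearity, so switching the activation should also switch the initializer.

```python
import flax.linen as nn

# `scale` in variance_scaling is the squared gain of the nonlinearity.
# Linear/identity activations: gain = 1, i.e. LeCun-style scaling.
linear_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
# ReLU halves the output variance, so the usual He-style correction is gain**2 = 2.
relu_init = nn.initializers.variance_scaling(2.0, "fan_in", "truncated_normal")

# Changing the activation without changing kernel_init silently drops this correction.
dense = nn.Dense(features=128, kernel_init=relu_init)
```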

Also, I was recently made aware of `torch.nn.init.calculate_gain(nonlinearity, param=None)`, which is a useful utility returning the recommended gain for a given activation function.
Maybe flax could also export something similar?
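For reference, a rough sketch of what such a helper could look like; the name `calculate_gain` and the gain values simply mirror PyTorch's convention, and nothing like this currently exists in Flax:

```python
import math

# Recommended gains, following the same convention as torch.nn.init.calculate_gain.
_GAINS = {
    "linear": 1.0,
    "sigmoid": 1.0,
    "tanh": 5.0 / 3.0,
    "relu": math.sqrt(2.0),
    "selu": 3.0 / 4.0,
}

def calculate_gain(nonlinearity: str, param=None) -> float:
    """Return the recommended scaling gain for the given nonlinearity."""
    if nonlinearity == "leaky_relu":
        negative_slope = 0.01 if param is None else param
        return math.sqrt(2.0 / (1.0 + negative_slope ** 2))
    return _GAINS[nonlinearity]
```

The square of this gain is exactly the `scale` one would pass to `variance_scaling`.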

Maybe a `DenseActivation` layer which automatically rescales the initialiser depending on the nonlinearity, if it is known?
That wouldn't work all the time, and maybe it is too smart, but it would be very beginner-friendly.
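A hypothetical sketch of that idea in Flax Linen; the `DenseActivation` name, the gain table, and the fallback behaviour are all assumptions, not an existing layer:

```python
from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp

# Hypothetical gain table keyed by the activation callable; gain**2 feeds variance_scaling.
_GAIN = {jax.nn.relu: 2.0 ** 0.5, jnp.tanh: 5.0 / 3.0, jax.nn.sigmoid: 1.0}


class DenseActivation(nn.Module):
    """Dense layer + activation, with the kernel initializer rescaled for that activation."""
    features: int
    activation: Callable = jax.nn.relu

    @nn.compact
    def __call__(self, x):
        gain = _GAIN.get(self.activation, 1.0)  # fall back to 1.0 for unknown activations
        kernel_init = nn.initializers.variance_scaling(
            gain ** 2, "fan_in", "truncated_normal")
        x = nn.Dense(self.features, kernel_init=kernel_init)(x)
        return self.activation(x)
```

Usage would then be e.g. `DenseActivation(features=128, activation=jnp.tanh)`, with the initializer tracking the activation automatically.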

commented

> Machine-Learning practitioners are aware that initialisation functions must be properly rescaled depending on the variance of the output of the nonlinearity.

I don't think this is still an up-to-date statement. There was a time when network training was very sensitive to the initializers because there were no skip connections or normalization.
Nowadays both of these techniques are almost universally applied and have made tuning the initializers largely irrelevant. It's been quite a while since I have seen a network whose initialization was actually corrected for the activation function.

That said, computing the gain might still be interesting, but it isn't as essential as it used to be.