google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


Initialization documentation for Conv seems to be wrong

sigeisler opened this issue


Problem you have encountered:

The docs say that the default is kernel_init=<function variance_scaling.<locals>.init>, but in the code the default is default_kernel_init = lecun_normal(). I think the same is true for the other linear layers as well.

Or what am I missing?
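
For reference, here is a minimal check of what I'm seeing. This sketch assumes flax.linen and jax.nn.initializers (module paths may differ in older versions), and the layer parameters are arbitrary:

import jax
from flax import linen as nn

conv = nn.Conv(features=8, kernel_size=(3, 3))
dense = nn.Dense(features=8)

# Both defaults print as the anonymous closure built by variance_scaling...
print(conv.kernel_init)   # <function variance_scaling.<locals>.init at 0x...>
print(dense.kernel_init)  # <function variance_scaling.<locals>.init at 0x...>

# ...which is exactly the kind of object lecun_normal() returns.
print(jax.nn.initializers.lecun_normal())  # <function variance_scaling.<locals>.init at 0x...>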

commented

This is really just a limitation of auto-generated docs. The LeCun initializer in JAX is defined as something like:

lecun_normal = functools.partial(variance_scaling, 1.0, "fan_in", "truncated_normal")

The downside of these shortcut definitions is that lecun_normal is not a function with the name "lecun_normal", so all the docs can render is the weird string you quoted above. I don't think there's an easy way for us to infer a proper name for these partially applied functions coming from JAX. You could open an issue on their side. I think we would probably just write it out as:

def lecun_normal(in_axis=-2, out_axis=-1):
  return variance_scaling(1.0, "fan_in", "truncated_normal",
                          in_axis=in_axis, out_axis=out_axis)

just to get better docs and names.
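
For what it's worth, here is a quick introspection sketch of what each spelling exposes. It only uses the stdlib plus jax.nn.initializers.variance_scaling, and the printed reprs are illustrative:

import functools
from jax.nn.initializers import variance_scaling

# Shortcut spelling: a partial object carries no __name__ of its own.
lecun_partial = functools.partial(variance_scaling, 1.0, "fan_in", "truncated_normal")
print(getattr(lecun_partial, "__name__", None))  # None

# Written-out spelling: the factory has a real name for tools to show.
def lecun_normal(in_axis=-2, out_axis=-1):
  return variance_scaling(1.0, "fan_in", "truncated_normal",
                          in_axis=in_axis, out_axis=out_axis)

print(lecun_normal.__name__)  # lecun_normal

# Note: the initializer either spelling returns is still the nested
# closure, so a default like kernel_init=lecun_normal() would keep
# rendering as <function variance_scaling.<locals>.init ...>.
print(lecun_normal())  # <function variance_scaling.<locals>.init at 0x...>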