Initialization documentation for Conv seems to be wrong
sigeisler opened this issue · comments
Problem you have encountered:
The docs say that the default is `kernel_init=<function variance_scaling.<locals>.init>`, but in the code it says `default_kernel_init = lecun_normal()`. I think the same is true for the other linear layers as well. Or what am I missing?
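For what it's worth, a minimal stand-in (hypothetical, not the real JAX code) suggests the two defaults are the same object — the string in the docs is just the repr of what `lecun_normal()` returns:

```python
# Dummy stand-in for jax.nn.initializers.variance_scaling, which
# returns an inner closure named `init`; not the real implementation.
def variance_scaling(mode, distribution):
    def init(key, shape):
        return (mode, distribution, shape)  # dummy body for illustration
    return init

def lecun_normal():
    return variance_scaling("fan_in", "normal")

# The default as written in the code:
default_kernel_init = lecun_normal()

# Its qualified name is what the auto-generated docs render:
print(default_kernel_init.__qualname__)  # variance_scaling.<locals>.init
```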
This is really just a limitation of auto-generated docs. The lecun initializer in JAX is defined as something like:

```python
lecun_normal = functools.partial(variance_scaling, "fan_in", "normal")
```

The downside of these shortcut definitions is that `lecun_normal` is then not a function with the name `lecun_normal` but a partially applied object that renders as the string you mentioned above. I don't think there's an easy way for us to infer a proper name for these partially applied functions that come from JAX. You could open an issue on their side. I think we would probably just write it out as:
```python
def lecun_normal(...):
    return variance_scaling("fan_in", "normal", ...)
```

just to get better docs and names.
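As a sketch of the difference (using a dummy stand-in for `variance_scaling`, not the real JAX implementation):

```python
import functools

def variance_scaling(mode, distribution):
    # Dummy stand-in: the real JAX variance_scaling also returns
    # an inner closure named `init`.
    def init(key, shape):
        return (mode, distribution, shape)
    return init

# Shortcut via partial: the object carries no name of its own,
# so autodoc falls back on the inner closure's qualified name.
lecun_partial = functools.partial(variance_scaling, "fan_in", "normal")
print(lecun_partial().__qualname__)   # variance_scaling.<locals>.init

# Written out as a def, the wrapper has a proper name for the docs.
def lecun_normal():
    return variance_scaling("fan_in", "normal")

print(lecun_normal.__name__)          # lecun_normal
```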