*Module Parameters* section of docs is outdated.
PaulScemama opened this issue · comments
Hi, first off thanks for a great library -- flax is awesome.
I wanted to revisit the documentation to gain a better understanding of flax. In basics there is a section on module parameters.
I wanted to point out that the code there does not appear to work at the moment.
Here is a stripped-down version of what is currently in the docs:
import flax.linen as nn
import jax.numpy as jnp
import jax.random as random

class SimpleDense(nn.Module):
    features: int
    kernel_init = nn.initializers.lecun_normal()

    @nn.compact
    def __call__(self, inputs):
        kernel = self.param('kernel',
                            self.kernel_init,  # Initialization function
                            (inputs.shape[-1], self.features))  # init_args
        y = jnp.dot(inputs, kernel)
        return y

x = jnp.ones((1, 7))
model = SimpleDense(features=3)
key, init_key = random.split(random.key(123))
params = model.init(init_key, x)
# Error: TypeError: Cannot interpret '7' as a data type
Seems to be something to do with how *init_args is being unpacked. I tried reproducing similar behaviour with the following:
initializer = nn.initializers.glorot_normal()

def foo(rng_key, args):
    def initialize():
        return nn.initializers.glorot_normal()(rng_key, *args)
    return initialize()

foo(random.key(1), (4, 5))
# TypeError: Cannot interpret '5' as a data type
But I had trouble navigating the flax codebase as I am unfamiliar with it. Thanks again!
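A note on the reproduction above: it may be exercising a different path than the original error. The sketch below uses a hypothetical stand-in, `fake_init`, whose `(key, shape, dtype)` signature is an assumption modeled on the initializers in this thread; it is not flax code. It shows that star-unpacking a bare shape tuple pushes the second dimension into the `dtype` slot, whereas a shape wrapped in a one-element tuple (which is how the `self.param` call above passes `(inputs.shape[-1], self.features)`, assuming it forwards its trailing arguments the same way) stays intact.

```python
def fake_init(key, shape, dtype=float):
    """Hypothetical stand-in for an initializer; it only reports its arguments."""
    return {"key": key, "shape": shape, "dtype": dtype}


def call_with_unpacking(key, init_args):
    # Mirrors the foo() helper above: every element of init_args
    # becomes its own positional argument.
    return fake_init(key, *init_args)


# Passing the bare shape tuple: 4 lands in `shape` and 5 lands in `dtype`.
shifted = call_with_unpacking("rng", (4, 5))

# Wrapping the shape in a one-element tuple keeps it together.
intact = call_with_unpacking("rng", ((4, 5),))

print(shifted)
print(intact)
```

If this reading is right, the `'5'` error from `foo` comes from the call site rather than from the unpacking inside flax.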
You need to add a type annotation to the dataclass field:
from typing import Callable

class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    ...
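Why the annotation matters can be sketched with the standard-library `dataclasses` module, on the assumption that flax modules follow the same dataclass rules (this thread refers to `kernel_init` as a dataclass field). `default_init`, `Annotated`, and `Unannotated` are hypothetical names used only for illustration. Only annotated attributes become fields; an unannotated function stays an ordinary class attribute, which Python treats as a method.

```python
import dataclasses
from typing import Callable


def default_init(key, shape):
    """Hypothetical stand-in for an initializer function."""
    return (key, shape)


@dataclasses.dataclass
class Annotated:
    features: int
    kernel_init: Callable = default_init  # annotated: becomes a dataclass field


@dataclasses.dataclass
class Unannotated:
    features: int
    kernel_init = default_init  # no annotation: plain class attribute (a method)


annotated_fields = [f.name for f in dataclasses.fields(Annotated)]
unannotated_fields = [f.name for f in dataclasses.fields(Unannotated)]
print(annotated_fields)
print(unannotated_fields)

# The field is stored on the instance, so the function comes back unbound.
a = Annotated(features=3)
print(a.kernel_init is default_init)

# The class attribute is looked up on the class and bound to the instance.
u = Unannotated(features=3)
print(u.kernel_init.__self__ is u)
```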
@chiamp thanks!
I also think an error message for a missing type annotation on the dataclass field might be good, since the error that resulted from it was a bit cryptic.
Not adding a type annotation turns kernel_init into a method of the class rather than a dataclass field:
class SimpleDense(nn.Module):
    features: int
    kernel_init = nn.initializers.lecun_normal()

SimpleDense.kernel_init(jax.random.key(0), (1, 1)) == nn.initializers.lecun_normal()(jax.random.key(0), (1, 1))
I believe there are use-cases for these, but @cgarciae can speak more to this.
Ahh I see @chiamp. So when we don't type annotate kernel_init, it becomes a bound method. E.g.:
from typing import Callable
import flax.linen as nn
import jax.numpy as jnp
import jax.random as random

class SimpleDense(nn.Module):
    features: int
    kernel_init = nn.initializers.lecun_normal()

x = jnp.ones((1, 7))
model = SimpleDense(features=3)
print(model.kernel_init)
# <bound method variance_scaling.<locals>.init of SimpleDense(
#     # attributes
#     features = 3
# )>
And then when we type annotate, it is just an attribute holding a plain function (not bound):
from typing import Callable
import flax.linen as nn
import jax.numpy as jnp
import jax.random as random

class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()

x = jnp.ones((1, 7))
model = SimpleDense(features=3)
print(model.kernel_init)
# <function variance_scaling.<locals>.init at 0x7f498fb66200>
In the former case, the binding shifted the arguments passed to it when self.param called it during initialization (see top of thread): the module instance is inserted as the first argument, so every other argument moves over by one position.
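The shift can be sketched in plain Python. `recording_init` is a hypothetical stand-in that assumes the `(key, shape, dtype)` signature of the initializers used here, and `call_like_param` assumes the initializer is invoked as `init_fn(rng, *init_args)`; neither is flax code. With a bound method, the instance lands in `key`, the RNG key in `shape`, and the shape tuple in `dtype`, which would be consistent with the `Cannot interpret '7' as a data type` error.

```python
from dataclasses import dataclass
from typing import Callable


def recording_init(key, shape, dtype=float):
    """Hypothetical stand-in for an initializer; it only reports its arguments."""
    return {"key": key, "shape": shape, "dtype": dtype}


def call_like_param(init_fn, rng, *init_args):
    # Assumption: the initializer is called with the RNG key first,
    # followed by the unpacked init_args.
    return init_fn(rng, *init_args)


@dataclass
class Unannotated:
    features: int
    kernel_init = recording_init  # bound to the instance on access


@dataclass
class Annotated:
    features: int
    kernel_init: Callable = recording_init  # stored as a plain field


broken = Unannotated(features=3)
shifted = call_like_param(broken.kernel_init, "rng", (7, 3))
print(shifted["shape"], shifted["dtype"])

fixed = Annotated(features=3)
correct = call_like_param(fixed.kernel_init, "rng", (7, 3))
print(correct["shape"], correct["dtype"])
```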