`apply` requires a PRNG key when it is not needed.
akolesnikoff opened this issue.
Problem you have encountered:
The `apply` call requires a PRNG key, which is only used in `setup()` for `Variable` initialization. As far as I can tell, the key is not really needed, and providing it makes code more complicated and less intuitive.
What you expected to happen:
I expect `apply` not to require a PRNG key.
Steps to reproduce:
Minimal example:
```python
import jax
import flax.linen as nn


class M(nn.Module):
    def setup(self):
        self.p = self.variable('state', 'state', nn.initializers.normal(),
                               self.make_rng('state'), (2, 2))

    def __call__(self, x):
        return self.p.value


m = M()
p = m.init({'state': jax.random.PRNGKey(0)}, 0.0)
m.apply(p, 0.0)  # raises InvalidRngError: no 'state' RNG is passed to apply
```
Error:
```
---------------------------------------------------------------------------
InvalidRngError Traceback (most recent call last)
<ipython-input-120-e878a6c947e2> in <module>()
11 p = m.init({'state': jax.random.PRNGKey(0)}, None)
12
---> 13 m.apply(p, 0.0)
10 frames
google3/third_party/py/flax/core/scope.py in make_rng(self, name)
500 """Generates A PRNGKey from a PRNGSequence with name `name`."""
501 if not self.has_rng(name):
--> 502 raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
503 self._check_valid()
504 self._validate_trace_level()
InvalidRngError: None needs PRNG for "state" (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.InvalidRngError)
```
This is expected behavior: `setup` is called during `apply` as well, so it should normally not create PRNG keys outside of variable initialization.
The `make_rng` API is independent of variable initialization; for example, you could use it like `dropout_rng = self.make_rng("dropout")`.
For initializers we normally use something equivalent to `self.variable('state', 'state', lambda *args: init_fn(self.make_rng("state"), *args), shape)`.
The `init_fn` will not be called during `apply` because the variable already has a value.
Thanks, calling `make_rng` in the initializer on demand solved it for me.
Hello, I don't quite get the idea.
I went through the documentation, but I couldn't understand the behavior of `make_rng` or even `nn.compact`; it seems to be doing some magic I cannot easily reproduce in a `setup(self):` definition.