google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

`apply` requires a PRNG key when it is not needed.

akolesnikoff opened this issue.

Problem you have encountered:

The `apply` call requires a PRNG key, which is only used in `setup()` for variable initialization. As far as I can tell, the key is not actually needed, and having to provide it makes the code more complicated and less intuitive.

What you expected to happen:

I expect apply not to require a PRNG key.

Steps to reproduce:

Minimal example:

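```python
import jax
import flax.linen as nn


class M(nn.Module):

  def setup(self):
    # make_rng('state') is evaluated eagerly as a call argument, so it runs on
    # every setup() call, including the one triggered by apply().
    self.p = self.variable('state', 'state', nn.initializers.normal(),
                           self.make_rng('state'), (2, 2))

  def __call__(self, x):
    return self.p.value


m = M()
p = m.init({'state': jax.random.PRNGKey(0)}, 0.0)

m.apply(p, 0.0)  # raises InvalidRngError
```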

Error:

```
---------------------------------------------------------------------------
InvalidRngError                           Traceback (most recent call last)
<ipython-input-120-e878a6c947e2> in <module>()
     11 p = m.init({'state': jax.random.PRNGKey(0)}, None)
     12 
---> 13 m.apply(p, 0.0)

10 frames
google3/third_party/py/flax/core/scope.py in make_rng(self, name)
    500     """Generates A PRNGKey from a PRNGSequence with name `name`."""
    501     if not self.has_rng(name):
--> 502       raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
    503     self._check_valid()
    504     self._validate_trace_level()

InvalidRngError: None needs PRNG for "state" (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.InvalidRngError)
```
A commenter replied:

This is expected behavior: `setup` is also called during `apply`, so it should normally not create PRNG keys except as part of variable initialization.
The `make_rng` API is independent of variable initialization; for example, you could use it as `dropout_rng = self.make_rng("dropout")`.
For initializers, we normally use something equivalent to `self.variable('state', 'state', lambda *args: init_fn(self.make_rng("state"), *args), shape)`.

The `init_fn` will not be called during `apply` because the variable already has a value.

Thanks, calling make_rng in the initializer on demand solved it for me.

Hello, I don't quite get the idea.
I went through the documentation, but I couldn't understand the behavior of `make_rng` or even `nn.compact`. It seems to be doing some magic that I cannot easily reproduce in a `setup(self)` definition.