google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`.

SuwoongHeo opened this issue

Problem you have encountered:

I'm getting an error with the tutorial example in cell [20] at https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html. Specifically, the following code raises an error:

from typing import Sequence

import jax
import jax.numpy as jnp
from jax import random

import flax.linen as nn
from flax.core import unfreeze


class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

Additionally, when I remove random.split and replace key1 and key2 with random.PRNGKey(0), it works fine. I think the error is related to the type of the output of random.split, which is an ndarray. Is there a way to fix this, or is it okay to just use random.PRNGKey(0) rather than random.split?
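For reference, a key produced by random.split has the same shape and dtype as one produced by random.PRNGKey, so both should be accepted. As a minimal sketch (assuming the imports above), model.init can also be given an explicit rngs dictionary, which is the form the error message asks for:

# Both forms should behave identically on a working installation:
params = model.init(key2, x)               # a single key is used as the 'params' rng
params = model.init({'params': key2}, x)   # explicit rngs dictionary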

What you expected to happen:

The result shown in the example:

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292815e-02 -4.3807115e-02  2.9323792e-02  6.5492536e-03
  -1.7147182e-02]
 [ 1.2967806e-01 -1.4551792e-01  9.4432183e-02  1.2521387e-02
  -4.5417298e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024032e-04  2.7864395e-05  2.4478821e-04  8.1344310e-04
  -1.0110770e-03]]

Logs, error messages, etc:

Traceback (most recent call last):
  File "/ssd2/swheo/dev/hypernerf/flaxtest.py", line 26, in <module>
    params = model.init(key2, x)
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 998, in init
    _, v_out = self.init_with_output(
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 968, in init_with_output
    return self.apply(
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 936, in apply
    return apply(
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 686, in wrapper
    with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 663, in bind
    raise errors.InvalidRngError(
flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.InvalidRngError)

Process finished with exit code 1
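The error is raised in flax/core/scope.py when an entry in the rngs dictionary does not look like a PRNGKey. A quick sanity-check sketch (in jax 0.2.x a raw key is just a uint32 array of shape (2,)):

import jax.numpy as jnp
from jax import random

key1, key2 = random.split(random.PRNGKey(0), 2)
# A valid (old-style) PRNGKey is a uint32 array of shape (2,).
print(key2.shape, key2.dtype)   # expected: (2,) uint32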

Steps to reproduce:

My jax and flax versions are:

jax 0.2.24 pypi_0 pypi
jaxlib 0.1.73 pypi_0 pypi
flax 0.3.4 pypi_0 pypi
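To double-check which versions are actually imported at runtime (rather than what conda reports), and whether the GPU backend loaded at all, something like the following can help (a sketch; jaxlib.version is assumed to be importable in these releases):

# Print the versions seen at runtime and the devices jax picked up.
import jax, jaxlib, flax
print('jax   ', jax.__version__)
print('jaxlib', jaxlib.version.__version__)
print('flax  ', flax.__version__)
print('devices', jax.devices())   # should list a GPU device if the CUDA backend loaded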

Thanks for reporting this, but I can't reproduce it.

Could you please verify that it works in the Colab I created, and explain how to reproduce your problem?

@marcvanzee Thanks for the reply. I moved to jax 0.2.20 with CUDA 11.1 and it works fine now. I think the problem was a version mismatch between the compiled version of jax and CUDA: I had multiple CUDA versions in my PATH (10.2, 11.1, 11.3). After removing 11.3 and reinstalling with pip using [cuda111]==0.2.0, the error no longer appears.

FYI, I tried running the code from your Colab on my local PC with CUDA 11.1 and it showed the same error, so I concluded this was a version-mismatch problem. I'm closing this issue myself. Thanks.