flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`.
SuwoongHeo opened this issue
Problem you have encountered:
I'm getting this error with the tutorial example in cell [20] of https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html. Specifically, the following code raises the error:
from typing import Sequence

import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from flax.core import unfreeze

class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4, 4))
model = SimpleMLP(features=[3, 4, 5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)
Additionally, when I remove `random.split` and replace `key1` and `key2` with `random.PRNGKey(0)`, it works fine. I think the error is related to the type of the output of `random.split`, which is an `ndarray`. Is there a way to fix this? Or is it okay to use `random.PRNGKey(0)` instead of `random.split`?
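To check the suspicion about the type returned by `random.split`, a small sketch like the following prints the Python type, shape and dtype of each key. It assumes JAX's default key representation (legacy `uint32` keys from the threefry implementation); with the newer typed keys from `jax.random.key`, or a different `jax_default_prng_impl`, the shapes differ. `describe` is a hypothetical helper written only for this example.

```python
import jax.numpy as jnp
from jax import random

# A single key from PRNGKey: with JAX's default threefry keys this is a
# uint32 array of shape (2,).
root = random.PRNGKey(0)

# split() returns the new keys stacked along a leading axis: shape (2, 2).
keys = random.split(root, 2)

# Tuple-unpacking iterates over that leading axis, so key1 and key2 each have
# the same shape and dtype as a key produced by PRNGKey directly.
key1, key2 = keys


def describe(name, k):
    """Hypothetical helper: print the Python type, shape and dtype of a key."""
    print(f"{name}: type={type(k).__name__} shape={k.shape} dtype={k.dtype}")


for name, k in [("root", root), ("keys", keys), ("key1", key1), ("key2", key2)]:
    describe(name, k)
```

If `key2` reports the same shape and dtype as `root`, the output of `random.split` is already an ordinary key, which points away from `split` itself and toward something in the installation.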
What you expected to happen:
The output shown in the tutorial:
initialized parameter shapes:
{'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
[[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03
-1.7147182e-02]
[ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02
-4.5417298e-02]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00]
[ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04
-1.0110770e-03]]
Logs, error messages, etc:
Traceback (most recent call last):
  File "/ssd2/swheo/dev/hypernerf/flaxtest.py", line 26, in <module>
    params = model.init(key2, x)
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 998, in init
    _, v_out = self.init_with_output(
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 968, in init_with_output
    return self.apply(
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 936, in apply
    return apply(
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 686, in wrapper
    with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
  File "/ssd2/swheo/dev/Anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 663, in bind
    raise errors.InvalidRngError(
flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.InvalidRngError)
Process finished with exit code 1
Steps to reproduce:
My jax, jaxlib and flax versions are:
jax 0.2.24 pypi_0 pypi
jaxlib 0.1.73 pypi_0 pypi
flax 0.3.4 pypi_0 pypi
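Because this report came down to package and CUDA versions, a short script can collect the installed versions together with the backend JAX actually selected. This is a sketch, not part of the original report: it uses the standard-library `importlib.metadata` (Python 3.8+), `package_versions` is a hypothetical helper, and it assumes `jax.default_backend()` is available, which holds for recent JAX releases.

```python
import importlib.metadata as metadata  # standard library, Python 3.8+

import jax


def package_versions(names=("jax", "jaxlib", "flax")):
    """Hypothetical helper: map each package to its installed version."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


report = package_versions()
# Which backend JAX picked up ("cpu", "gpu", ...) and the devices it sees.
# A CPU-only result on a GPU machine often signals a jaxlib/CUDA mismatch.
report["backend"] = jax.default_backend()
report["devices"] = [str(d) for d in jax.devices()]

for key, value in report.items():
    print(f"{key}: {value}")
```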
Thanks for reporting this, but I can't reproduce it:
- The Flax basics Colab executes successfully.
- A Colab I created myself with this code also executes successfully.
Could you please verify it is working in the Colab I created, and explain how to reproduce your problem?
@marcvanzee Thanks for the reply. I moved to jax 0.2.20 with CUDA 11.1 and it works fine now. I think the problem was a mismatch between the CUDA version my jax build was compiled against and the CUDA on my system: I had multiple CUDA versions (10.2, 11.1, 11.3) in my path variable. After removing 11.3 and installing with pip using `[cuda111]==0.2.0`, the error no longer appears.
FYI, I ran the code from your Colab on my local PC with CUDA 11.1 and got the same error, so I concluded it was a version-mismatch problem. I'm closing this issue myself. Thanks.