Generating random numbers – None PRNGKey error
cifkao opened this issue
I tried modifying the model in a way that requires generating random numbers inside the Transformer layer. Specifically, I added a call to `hk.next_rng_key()` to `TransformerLayerShard.__call__` so that I can have a different random number for each batch. This results in the following error during training:
```
...
  File "/home/ocifka/mesh-transformer-jax/mesh_transformer/layers.py", line 312, in __call__
    key = hk.next_rng_key()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 638, in next_rng_key
    return next_rng_key_internal()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 643, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 599, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.
```
As far as I can tell, this happens because a `None` PRNGKey is passed by default on this line:
I would appreciate any advice on where I should pass my PRNGKey in order to get a different random number for each training batch.
You can remove the call to `without_apply_rng`, but then you need to pass in an RNG key every time you call `train_loss_fn`.