kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

Generating random numbers – None PRNGKey error

cifkao opened this issue · comments

I tried modifying the model in a way that requires generating random numbers inside the Transformer layer. Specifically, I added a call to hk.next_rng_key() to TransformerLayerShard.__call__ so that I can have a different random number for each batch. This results in the following error during training:

  ...
  File "/home/ocifka/mesh-transformer-jax/mesh_transformer/layers.py", line 312, in __call__
    key = hk.next_rng_key()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 638, in next_rng_key
    return next_rng_key_internal()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 643, in next_rng_key_internal
    rng_seq = rng_seq_or_fail()
  File "/home/ocifka/.local/lib/python3.8/site-packages/haiku/_src/base.py", line 599, in rng_seq_or_fail
    raise ValueError("You must pass a non-None PRNGKey to init and/or apply "
ValueError: You must pass a non-None PRNGKey to init and/or apply if you make use of random numbers.

As far as I can tell, this happens because a None PRNGKey is passed to apply by default, because of this line:

train_loss_fn = hk.without_apply_rng(hk.transform(train_loss)).apply

I would appreciate any advice on where I should pass my PRNGKey in order to get a different random number for each training batch.

You can remove the call to without_apply_rng, but then you need to feed in an RNG key every time you call train_loss_fn.