`FLAX_PROFILE=1` changes the computation
rsepassi opened this issue
Ryan Sepassi commented
Problem you have encountered:
run.py:
```python
import jax, numpy as np
from flax import linen as nn
print(nn.Dense(1, use_bias=False).init({'params': jax.random.PRNGKey(0)}, np.ones((1, 1), np.float32)))
```
FLAX_PROFILE= python run.py -> -1.4588
FLAX_PROFILE=1 python run.py -> 1.3333
What you expected to happen:
I would expect a profiling option like FLAX_PROFILE not to change any computation.
jheek commented
Ah yes, the RNGs get split differently as a side effect of named_call. We should definitely fix that.
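To illustrate why a different split order alone is enough to change the result, here is a minimal JAX-only sketch. It does not reproduce Flax's internals: `init_kernel` is a hypothetical stand-in for a parameter initializer, and "path B" simply assumes a wrapper (playing the role of named_call) consumes one extra split before the parameter key is derived. Starting from the same seed, the two paths end up with different keys and therefore different kernel values.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a parameter initializer: draws a (1, 1) kernel.
def init_kernel(key):
    return jax.random.normal(key, (1, 1))

root = jax.random.PRNGKey(0)

# Path A: the parameter key comes straight from one split of the root key.
params_key_a = jax.random.split(root)[1]
kernel_a = init_kernel(params_key_a)

# Path B (assumption): a wrapper consumes one extra split before the
# parameter key is derived, so the initializer sees a different key.
wrapper_key = jax.random.split(root)[1]
params_key_b = jax.random.split(wrapper_key)[1]
kernel_b = init_kernel(params_key_b)

print(kernel_a, kernel_b)  # same root seed, different kernel values
```

Repeating path A always gives the same kernel, so the computation stays deterministic within one setting; it is the extra split introduced by the wrapper that makes the two settings disagree.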
Ryan Sepassi commented
Thanks!