google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

`FLAX_PROFILE=1` changes the computation

rsepassi opened this issue

Problem you have encountered:

run.py:
import jax, numpy as np, flax.linen as nn
print(nn.Dense(1, use_bias=False).init({'params': jax.random.PRNGKey(0)}, np.ones((1, 1), np.float32)))

FLAX_PROFILE= python run.py -> -1.4588
FLAX_PROFILE=1 python run.py -> 1.3333

What you expected to happen:

I would expect a profiling option like FLAX_PROFILE to not change any computation.
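One way to state this expectation is as a regression check: run the same computation with and without a profiling-only annotation and require matching outputs. The sketch below uses plain JAX rather than Flax's FLAX_PROFILE machinery. It wraps the function in `jax.named_call`, which only attaches a name for profiling. The `forward` function is hypothetical and simply stands in for a layer that draws its kernel from an RNG key.

```python
import jax
import jax.numpy as jnp

def forward(key, x):
    # Hypothetical model: draw a kernel from `key` and apply it to `x`.
    kernel = jax.random.normal(key, (x.shape[-1], 1))
    return x @ kernel

key = jax.random.PRNGKey(0)
x = jnp.ones((1, 1), jnp.float32)

# Baseline run with no profiling annotation.
plain = jax.jit(forward)(key, x)
# Same function wrapped in a naming-only annotation; it should not alter the math.
annotated = jax.jit(jax.named_call(forward, name="forward"))(key, x)

assert jnp.allclose(plain, annotated), "profiling annotation changed the result"
print(plain, annotated)
```

A check like this would have caught the bug above, where the reported value moved from -1.4588 to 1.3333 depending only on the profiling flag.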

commented

Ah yes, the RNGs get split differently as a side effect of `named_call`. We should definitely fix that.
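To see why a different split path changes the numbers, here is a minimal sketch in plain JAX. It is not Flax's actual `named_call` code. `init_kernel`, `init_direct`, and `init_wrapped` are hypothetical helpers. The "wrapper" splits the key once before delegating, which is assumed here to stand in for the extra split the profiling path introduces. The same seed then yields a different kernel.

```python
import jax
import jax.numpy as jnp

def init_kernel(key, shape):
    # Hypothetical initializer: draws weights directly from `key`.
    return jax.random.normal(key, shape)

def init_direct(key):
    # Plain path: the key is consumed as-is.
    return init_kernel(key, (1, 1))

def init_wrapped(key):
    # Hypothetical wrapper (e.g. a profiling scope) that splits the key once
    # before delegating; this alone changes which weights get drawn.
    key, _ = jax.random.split(key)
    return init_kernel(key, (1, 1))

key = jax.random.PRNGKey(0)
a = init_direct(key)
b = init_wrapped(key)
print(a, b)  # same seed, different values
```

Each path is deterministic on its own, so the fix is to make the profiling path consume the key exactly as the plain path does, rather than to change the seed.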

Thanks!