FactoredState considered different in optax and flax.serialization
borisdayma opened this issue
Problem you have encountered:
When serializing/deserializing the optax state, the pytree identifies nodes as `namedtuple[<class 'optax._src.factorized.FactoredParameterStats'>]`, while when the optimizer applies gradients, those nodes are typed as `flax.serialization.FactoredState`.
This causes a compilation error when loading a previous optimizer state if we use `jax.lax.cond` to apply gradients conditionally.
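For illustration, here is a minimal sketch (not from the original report) of why `jax.lax.cond` rejects this: two namedtuple classes with the same name and fields are still distinct pytree node types, so the two branch outputs have different PyTreeDefs. The class names below are hypothetical stand-ins for the optax and flax types.

```python
import collections
import jax
import jax.numpy as jnp

# Two distinct classes that happen to share the same name and fields; JAX
# registers namedtuples as pytree nodes by class, so these are different
# node types even though their structure is identical.
OptaxLikeState = collections.namedtuple("FactoredState", ["v"])
FlaxLikeState = collections.namedtuple("FactoredState", ["v"])

# Both branches return structurally identical values built from different
# classes, so cond rejects them.
jax.lax.cond(
    True,
    lambda _: OptaxLikeState(v=jnp.zeros(1)),
    lambda _: FlaxLikeState(v=jnp.zeros(1)),
    None,
)
# -> TypeError: true_fun and false_fun output must have same type structure ...
```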
What you expected to happen:
I would expect `optax._src.factorized.FactoredParameterStats` and `flax.serialization.FactoredState` to be treated as the same type, so that no compilation error is raised.
Logs, error messages, etc:
The error I have is:
```
Traceback (most recent call last):
  File "run_seq2seq_flax.py", line 953, in <module>
    main()
  File "run_seq2seq_flax.py", line 891, in main
    state, train_metric = p_train_step(state, batch)
  File "run_seq2seq_flax.py", line 728, in train_step
    new_state = jax.lax.cond(
TypeError: true_fun and false_fun output must have same type structure, got PyTreeDef(xxxxxxx)
and PyTreeDef(xxxxxxxx)
```
The full PyTreeDefs are very long. Comparing them, the only difference is that one uses `optax._src.factorized.FactoredState` while the other uses `flax.serialization.FactoredState` (the type appears in 208 places in each PyTreeDef).
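As a quick check that the structures disagree, the two PyTreeDefs can be compared directly. This is a sketch; `fresh_opt_state` and `restored_opt_state` are hypothetical names for the freshly initialized state and the one restored via `flax.serialization`.

```python
import jax

# The two states hold the same values, but on affected versions their node
# types differ, so the tree structures do not compare equal.
print(jax.tree_util.tree_structure(fresh_opt_state) ==
      jax.tree_util.tree_structure(restored_opt_state))  # -> False
```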
Steps to reproduce:
- checkpoint `state.opt_state` with `flax.serialization.to_bytes`
- reload `opt_state` with `flax.serialization.from_bytes`
- use `state.replace(opt_state=opt_state)` (after creating the state)
- update gradients conditionally (a condensed repro sketch follows the snippet below):
```python
new_state = jax.lax.cond(
    (state.step + 1) % training_args.gradient_accumulation_steps == 0,
    lambda _: state.apply_gradients(.....),
    lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
    None,
)
```
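Putting the steps above together, a condensed repro sketch (assuming an adafactor optimizer, which is what produces the factored statistics, and a dummy parameter tree) might look like this:

```python
import jax.numpy as jnp
import optax
from flax import serialization

# A dummy parameter tree large enough for adafactor to factor its statistics.
params = {"w": jnp.ones((128, 128))}
tx = optax.adafactor(learning_rate=1e-3)
opt_state = tx.init(params)

# Checkpoint and reload the optimizer state.
data = serialization.to_bytes(opt_state)
restored = serialization.from_bytes(opt_state, data)

# On affected versions, the namedtuple nodes inside `restored` are rebuilt as
# flax.serialization classes rather than the original optax classes, so using
# `restored` in one branch of jax.lax.cond and a freshly updated state in the
# other reproduces the TypeError above.
```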
I tried to create a separate `train_step` for the first step only (one that does not include any condition) to force the state type to always be consistent, but this extra function causes a memory error since I already use the TPU memory completely, even if I keep just one function and use `static_broadcasted_argnums` on my `pmap` call.
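One possible workaround (my suggestion, not from the original thread) is to coerce the restored values back into the original node types by reusing the fresh state's tree structure. This assumes both trees flatten their leaves in the same order, which should hold here since the structures are field-for-field identical; `fresh_opt_state` and `restored` are the hypothetical names from the sketches above.

```python
import jax

# Keep the deserialized leaf values but rebuild the tree with the fresh
# state's structure, so the nodes get the original optax namedtuple classes.
coerced = jax.tree_util.tree_unflatten(
    jax.tree_util.tree_structure(fresh_opt_state),
    jax.tree_util.tree_leaves(restored),
)
```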
I recently fixed this, but we need to cut a new release on pip. We are a bit short-staffed at the moment due to the holiday season but will probably make a new release pretty soon. As a quick fix, consider installing the code from main (previously master).
Closing this since it is fixed in our main branch.
For reference: this was fixed in #1432