google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

FactoredState considered different in optax and flax.serialization

borisdayma opened this issue

Problem you have encountered:

When serializing/deserializing the optax state, the pytree identifies those nodes as namedtuple[<class 'optax._src.factorized.FactoredParameterStats'>], while when the optimizer applies gradients, the same nodes are typed as flax.serialization.FactoredState.

This causes a compilation error when loading a previous optimizer state if we use jax.lax.cond to apply gradients conditionally.
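
The structural requirement can be shown in isolation. Below is a minimal, self-contained sketch using hypothetical namedtuple classes as stand-ins for the two FactoredState definitions (they are not the real optax/flax types): JAX keys namedtuple pytree nodes on their class, so two classes with identical fields still produce different PyTreeDefs, and jax.lax.cond rejects branches whose outputs differ in that way.

import collections
import jax

# Hypothetical stand-ins for the two FactoredState definitions: identical
# field names, but two distinct Python classes.
StateA = collections.namedtuple("FactoredState", ["count", "v"])
StateB = collections.namedtuple("FactoredState", ["count", "v"])

# The pytree registry keys namedtuple nodes on the class itself, so the two
# treedefs compare unequal even though the field layout is the same.
print(jax.tree_util.tree_structure(StateA(count=0, v=1.0))
      == jax.tree_util.tree_structure(StateB(count=0, v=1.0)))  # False

# jax.lax.cond requires both branches to return the same tree structure,
# which is exactly the check that fails in the traceback below.
try:
    jax.lax.cond(
        True,
        lambda _: StateA(count=1, v=2.0),
        lambda _: StateB(count=1, v=2.0),
        None,
    )
except TypeError as e:
    print(e)  # "true_fun and false_fun output must have same type structure ..."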

What you expected to happen:

I would expect optax._src.factorized.FactoredParameterStats and flax.serialization.FactoredState to be treated as the same pytree node type, so that no compilation error is raised.

Logs, error messages, etc:

The error I have is:

Traceback (most recent call last):
  File "run_seq2seq_flax.py", line 953, in <module>
    main()
  File "run_seq2seq_flax.py", line 891, in main
    state, train_metric = p_train_step(state, batch)
  File "run_seq2seq_flax.py", line 728, in train_step
    new_state = jax.lax.cond(
TypeError: true_fun and false_fun output must have same type structure, got PyTreeDef(xxxxxxx)
and PyTreeDef(xxxxxxxx)

The full PyTreeDefs are very long. Comparing them, the only difference is that one uses optax._src.factorized.FactoredState while the other uses flax.serialization.FactoredState (it appears in 208 places in each PyTreeDef).

Steps to reproduce:

  • checkpoint state.opt_state with flax.serialization.to_bytes
  • reload opt_state with flax.serialization.from_bytes
  • use state.replace(opt_state=opt_state) (after creating the state)
  • apply gradients conditionally (a self-contained sketch of these four steps follows the snippet below):
new_state = jax.lax.cond(
    (state.step + 1) % training_args.gradient_accumulation_steps == 0,
    lambda _: state.apply_gradients(.....),
    lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
    None,
)
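
For reference, here is a self-contained sketch of those four steps; the model, parameter shape, learning rate, and dummy gradients are illustrative assumptions and are not taken from run_seq2seq_flax.py.

import jax
import jax.numpy as jnp
import optax
from flax import serialization
from flax.training import train_state

# Illustrative stand-in setup: a single weight matrix trained with Adafactor,
# whose optimizer state contains the FactoredState nodes at issue.
params = {"w": jnp.ones((4, 4))}
tx = optax.adafactor(learning_rate=1e-3)
state = train_state.TrainState.create(
    apply_fn=lambda params, x: x @ params["w"], params=params, tx=tx
)

# 1. checkpoint state.opt_state with flax.serialization.to_bytes
ckpt_bytes = serialization.to_bytes(state.opt_state)

# 2. reload opt_state with flax.serialization.from_bytes, using the freshly
#    created opt_state as the target structure
opt_state = serialization.from_bytes(state.opt_state, ckpt_bytes)

# 3. attach the restored opt_state to the (re-created) state
state = state.replace(opt_state=opt_state)

# 4. apply gradients conditionally; before the fix, the restored opt_state held
#    flax.serialization-defined namedtuple nodes while apply_gradients produced
#    optax-defined ones, so the two lax.cond branches had different pytree
#    structures and tracing failed with the TypeError shown above.
grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
new_state = jax.lax.cond(
    (state.step + 1) % 2 == 0,
    lambda _: state.apply_gradients(grads=grads),
    lambda _: state.replace(step=state.step + 1),
    None,
)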

I tried creating a separate train_step just for the first step (one without any condition) to force the state to always be consistent, but this extra function causes an out-of-memory error since I am already using the TPU's full memory, even if I keep only one function and use static_broadcasted_argnums on my pmap call.

commented

I recently fixed this, but we need to create a new release on pip. We are a bit short-staffed at the moment due to the holiday season, but we will probably make a new release pretty soon. Consider installing the code from main (previously master) as a quick fix.

Closing this since it is fixed in our main branch.

For reference: this was fixed in #1432