google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Deprecation Warnings with orbax 0.5.3

lucidfrontier45 opened this issue

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian 12
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.8.1, orbax-checkpoint 0.5.3
  • Python version: 3.11.4

Problem you have encountered:

I followed the save and load checkpoints tutorial and got deprecation warnings. Although the checkpoints were saved correctly, it would be great if the tutorial documented the current, non-deprecated way of saving and loading a Flax TrainState.

What you expected to happen:

no warnings

Logs, error messages, etc:

WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024.

Steps to reproduce:

Just follow the save and load checkpoints tutorial.

Any thoughts on how to checkpoint a Flax TrainState with the new CheckpointManager API? I gave it a go on StackOverflow, but without success.

Hi - one of the Orbax team members replied to your StackOverflow thread, please take a look.

Regarding the SaveArgs.aggregate deprecation warning: it seems that internally aggregate=True is set whenever Orbax tries to save an empty node such as optax.EmptyState() (which is part of the TrainState as optimizer state).
The Orbax team will work on a refactoring that removes the internal use of aggregate. Meanwhile, the whole TrainState is still saved correctly despite the deprecation warning, so it probably will not affect your use.
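If the warning is only noise in the meantime, one option is to filter that single message out of the logs. The sketch below uses only the standard library and assumes the message is emitted through the logger named "absl" (which the WARNING:absl: prefix in the logs above suggests); the class name DropAggregateDeprecation is made up for illustration. It hides the message and does not change what Orbax does.

```python
import logging


class DropAggregateDeprecation(logging.Filter):
    """Drops only the SaveArgs.aggregate deprecation message (hypothetical helper)."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False discards the record; everything else passes through.
        return "SaveArgs.aggregate is deprecated" not in record.getMessage()


# Assumption: Orbax logs through the "absl" logger, as the log prefix suggests.
logging.getLogger("absl").addFilter(DropAggregateDeprecation())
```

Other absl warnings, including the legacy CheckpointManager one, are still shown, so a real migration deadline would not be hidden by accident.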

I will make another PR that addresses the deprecated CheckpointManager API warning in the Flax Orbax guide.

Hi,

I am still receiving this warning:

WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024. If your Pytree has empty ([], {}, None) values then use PyTreeCheckpointHandler(..., write_tree_metadata=True, ...) or use StandardCheckpointHandler to avoid TypeHandler Registry error. Please note that PyTreeCheckpointHandler.write_tree_metadata default value is already set to T .

Has the fix been implemented?

Below is my code:

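import os
from pathlib import Path
import orbax.checkpoint as ocp

# Checkpointing: keep the 5 most recent steps and create the directory if needed.
# out_dir, state, iter_num, model_args, best_val_loss, losses and config come from the training script.
check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(
    check_path, options=check_options, item_names=('state', 'metadata'))
# Save the train state and the JSON metadata together as one composite checkpoint.
checkpoint_manager.save(
    step=iter_num,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        metadata=ocp.args.JsonSave(
            (model_args, iter_num, best_val_loss, losses['val'].item(), config))))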

orbax-checkpoint version: 0.5.9