google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

WARNING:absl:SaveArgs.aggregate is deprecated

jiagaoxiang opened this issue

Hi,

I am receiving the following warning:

WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024. If your Pytree has empty ([], {}, None) values then use PyTreeCheckpointHandler(..., write_tree_metadata=True, ...) or use StandardCheckpointHandler to avoid TypeHandler Registry error. Please note that PyTreeCheckpointHandler.write_tree_metadata default value is already set to T .

How can I fix this?

Below is my code:

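# Checkpointing (state, out_dir, iter_num, model_args, best_val_loss, losses and config are defined earlier in the training script)
import os
from pathlib import Path
import orbax.checkpoint as ocp

# Keep at most 5 checkpoints and create the checkpoint directory if it does not exist.
check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(
    check_path, options=check_options, item_names=('state', 'metadata'))
# Save the train state as a PyTree and the run metadata as JSON under the same step.
checkpoint_manager.save(
    step=iter_num,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        metadata=ocp.args.JsonSave(
            (model_args, iter_num, best_val_loss, losses['val'].item(), config))))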

orbax-checkpoint version: 0.5.9

The warning looks to be erroneous; it probably has to do with your state containing some empty nodes ([], {}, None) or similar. I'd say don't worry about it. We're doing some refactoring that should make erroneous warnings like this go away.
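To check the maintainer's guess, one way is to list any empty nodes in the state before saving it. The sketch below is an illustration, not an Orbax utility: find_empty_nodes and _is_empty_node are hypothetical helper names, and it assumes state is an ordinary JAX pytree. It calls jax.tree_util.tree_flatten_with_path with an is_leaf predicate so that None and empty containers, which JAX normally flattens to zero leaves, appear in the output together with their key paths.

```python
import jax


def _is_empty_node(x):
    # Hypothetical helper: treat None and empty list/tuple/dict as leaves
    # so that the flattening step reports them instead of silently dropping them.
    return x is None or (isinstance(x, (list, tuple, dict)) and len(x) == 0)


def find_empty_nodes(tree):
    """Return the key paths of None / [] / () / {} values in a pytree (hypothetical helper)."""
    flat, _ = jax.tree_util.tree_flatten_with_path(tree, is_leaf=_is_empty_node)
    # Keep only the "leaves" that are actually empty nodes, rendered as readable key paths.
    return [jax.tree_util.keystr(path) for path, leaf in flat if _is_empty_node(leaf)]


if __name__ == "__main__":
    # Hypothetical example state; in the training script this would be `state`.
    example_state = {
        "params": {"w": [1.0, 2.0], "b": None},
        "opt_state": (),
        "extra": {},
        "step": 3,
    }
    print(find_empty_nodes(example_state))
```

Calling find_empty_nodes(state) just before checkpoint_manager.save(...) and getting a non-empty list would be consistent with the explanation above, though it does not by itself prove that those nodes are what triggers this particular warning.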