google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

Checkpoint Manager using different directory paths for save and restore

svarunid opened this issue.

I was trying to save and restore model and opt_state using CheckpointManager, and I noticed two issues. When saving, the checkpoint manager creates a temporary directory and writes the checkpoint there:

tmp_step_dir = self._create_tmp_directory(save_directory)

The name of this temporary directory adds an extra timestamp to the step directory under the path passed in when the checkpoint manager is initialized. However, when restoring, this path is not resolved, which leads to a "directory not found" error:
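The two code excerpts above are consistent with a write-to-temporary-then-rename pattern: a step is written under a suffixed temporary directory and only appears under its plain step directory once it is finalized. What follows is a minimal stdlib sketch of that pattern, not Orbax's implementation. The helpers write_tmp, finalize and restore and the data.bin file layout are hypothetical. The suffix only mimics the temporary path reported in this issue, and the sketch assumes a local filesystem where os.rename moves a directory in one step.

```python
import os
import tempfile
import time

# Suffix modeled on the temporary path reported in this issue (an assumption,
# not a guarantee of Orbax's exact naming scheme).
TMP_MARKER = ".orbax-checkpoint-tmp-"


def write_tmp(root: str, step: int, payload: bytes) -> str:
    """Write a step's data into a timestamp-suffixed temporary directory."""
    # Microsecond timestamp, similar in shape to the suffix seen in the issue.
    tmp_dir = os.path.join(root, str(step)) + TMP_MARKER + str(time.time_ns() // 1000)
    os.makedirs(os.path.join(tmp_dir, "model"))
    with open(os.path.join(tmp_dir, "model", "data.bin"), "wb") as f:
        f.write(payload)
    return tmp_dir


def finalize(root: str, step: int, tmp_dir: str) -> str:
    """Rename the temporary directory to its final, restorable name."""
    final_dir = os.path.join(root, str(step))
    os.rename(tmp_dir, final_dir)
    return final_dir


def restore(root: str, step: int) -> bytes:
    """Look only at the final step directory; temporary ones are ignored."""
    model_dir = os.path.join(root, str(step), "model")
    path = os.path.join(model_dir, "data.bin")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint at {model_dir} not found.")
    with open(path, "rb") as f:
        return f.read()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        tmp = write_tmp(root, 0, b"weights")
        try:
            restore(root, 0)  # fails: only the temporary directory exists
        except FileNotFoundError as e:
            print(e)
        finalize(root, 0, tmp)
        print(restore(root, 0))  # b'weights'
```

In this sketch, a reader that looks up the step's model directory between write_tmp and finalize finds nothing. That is the same shape as the "not found" error reported in this issue, even though the data is already on disk under the temporary name.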

directory = self.directory
path = self._get_save_directory(step, directory, item_name)
self._checkpointers[item_name].restore(
    path, item=item, **kwargs
)

Here's my code:

import orbax.checkpoint as ocp
from etils import epath

path = epath.Path('/nmt-attention-checkpoints/')

mngr_options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=25
)

mngr = ocp.CheckpointManager(
    path,
    {
        "model": ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler()),
    "opt_state": ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
    },
    mngr_options
)

mngr.save(
    0,
    {
        "model": model,
        "opt_state": opt_state
    }
)

mngr.restore(step=mngr.latest_step())

Path where the checkpoints are saved: nmt-attention-checkpoints\0.orbax-checkpoint-tmp-1703334121930256

I get the following error when restoring my checkpoints: Checkpoint at \nmt-attention-checkpoints\0\model not found.

You're checkpointing asynchronously and restoring without waiting for the background save to complete. Add a wait_until_finished call before restoring. The orbax-checkpoint-tmp... suffix indicates that the checkpoint is not yet complete and cannot be restored (or, if there was a failure before finalization, that the checkpoint is likely garbage).
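The race described in this reply can be reproduced without Orbax. Below is a stdlib-only simulation. It assumes a background thread that writes to a suffixed temporary directory and renames it when done. ToyAsyncCheckpointer and its write_delay parameter are hypothetical. Only the method names save, wait_until_finished and restore echo the calls discussed in this thread; their behavior here is a simplified stand-in, not the Orbax API. A restore issued right after save races the background write, while one issued after wait_until_finished succeeds.

```python
import os
import tempfile
import threading
import time


class ToyAsyncCheckpointer:
    """Hypothetical async checkpointer used only to illustrate the race.

    It is NOT the Orbax API; it only mimics the save / wait_until_finished /
    restore sequence discussed in this thread.
    """

    def __init__(self, root: str, write_delay: float = 0.5):
        self.root = root
        self.write_delay = write_delay  # simulated slow storage, in seconds
        self._thread = None

    def save(self, step: int, payload: bytes) -> None:
        final_dir = os.path.join(self.root, str(step))
        tmp_dir = f"{final_dir}.orbax-checkpoint-tmp-{time.time_ns() // 1000}"

        def write_and_finalize() -> None:
            os.makedirs(os.path.join(tmp_dir, "model"))
            time.sleep(self.write_delay)
            with open(os.path.join(tmp_dir, "model", "data.bin"), "wb") as f:
                f.write(payload)
            os.rename(tmp_dir, final_dir)  # the step becomes visible only here

        # save() returns immediately; writing continues on a background thread.
        self._thread = threading.Thread(target=write_and_finalize)
        self._thread.start()

    def wait_until_finished(self) -> None:
        """Block until the pending background save has been finalized."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def restore(self, step: int) -> bytes:
        model_dir = os.path.join(self.root, str(step), "model")
        path = os.path.join(model_dir, "data.bin")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Checkpoint at {model_dir} not found.")
        with open(path, "rb") as f:
            return f.read()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        ckpt = ToyAsyncCheckpointer(root)
        ckpt.save(0, b"weights")
        try:
            ckpt.restore(0)  # races the background write and fails
        except FileNotFoundError as e:
            print("too early:", e)
        ckpt.wait_until_finished()  # block until the save is finalized
        print(ckpt.restore(0))
```

In the reporter's script, the corresponding change is to call mngr.wait_until_finished() between mngr.save(...) and mngr.restore(...).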

Thanks! I didn't notice that CheckpointManager has its own wait_until_finished method.