google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

Struggling to restore metadata on other device

thorben-frank opened this issue

Hello,

I am trying to load the metadata of a checkpoint on a different device via the CheckpointManager API, but I can't find a way to make it work. Below is a minimal example of what I am trying to do.

First I do "training" on GPU by running:

from orbax import checkpoint as ocp
import pathlib
import jax.numpy as jnp

ckpt_dir = pathlib.Path('.').expanduser().absolute()

ckpt_mngr = ocp.CheckpointManager(
    ocp.test_utils.create_empty(ckpt_dir / 'checkpoints'),
    item_names=('params', )
)

params = {'a': jnp.array([1.])}

for i in range(10):
    ckpt_mngr.save(
        i,
        args=ocp.args.Composite(params=ocp.args.StandardSave(params)),
    )
# Saving is asynchronous; block until all checkpoints are written.
ckpt_mngr.wait_until_finished()

I then copy the checkpoint to my local machine, which only has a CPU. When I try to read the metadata there, I get the following behaviour:

# Load with the old API
from orbax import checkpoint as ocp
import pathlib
import jax.numpy as jnp

ckpt_dir = pathlib.Path('.').expanduser().absolute()

ckpt_load = ocp.CheckpointManager(
    ckpt_dir / 'checkpoints',
    {'params': ocp.PyTreeCheckpointer()},
)
latest_step = ckpt_load.latest_step()
ckpt_load.item_metadata(0)

Gives me

File ~/Documents/venvs/mlff/lib/python3.9/site-packages/orbax/checkpoint/type_handlers.py:164, in _deserialize_sharding_from_json_string(sharding_string)
    159   if device := _deserialize_sharding_from_json_string.device_map.get(
    160       device_str, None
    161   ):
    162     return SingleDeviceSharding(device)
--> 164   raise ValueError(
    165       f'{ShardingTypes.SINGLE_DEVICE_SHARDING.value} with'
    166       f' Device={device_str} was not found in jax.local_devices().'
    167   )
    169 else:
    170   raise NotImplementedError(
    171       'Sharding types other than `jax.sharding.NamedSharding` have not been '
    172       'implemented.'
    173   )

ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().

Does that mean that metadata can only be read on the same device the checkpoint was saved on? How else could I get the pytree structure without calling model.init itself? Following issue #648, I could delete the _sharding file and then restore the metadata by setting restore_kwargs appropriately. However, this only works with the old API (see below) and seems a bit hacky, so I suspect I am doing something wrong. Using the new API:

# Load with new API

ckpt_dir = pathlib.Path('.').expanduser().absolute()

ckpt_load = ocp.CheckpointManager(
    ckpt_dir / 'checkpoints',
    item_names=('params', )
)
latest_step = ckpt_load.latest_step()
ckpt_load.item_metadata(0)

I get

CompositeArgs({})

so no metadata at all.

We're working on a fix for this; unfortunately, the sharding metadata doesn't work well in every case yet. If you must read the metadata, delete the _sharding file and keep using the old API for now.
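A minimal sketch of that workaround, assuming the checkpoint layout `<checkpoints>/<step>/<item>/_sharding` produced by the example at the top of this issue (step directories named by step number, one subdirectory per item such as `params`). The helper name `remove_sharding_files` is hypothetical. Deleting the file throws away the stored sharding information, so run it on a copy of the checkpoint.

```python
import pathlib
import tempfile


def remove_sharding_files(checkpoint_root: pathlib.Path) -> list[pathlib.Path]:
    """Deletes every `<step>/<item>/_sharding` file under `checkpoint_root`.

    Hypothetical helper; assumes the layout <root>/<step>/<item>/_sharding.
    Returns the paths that were removed.
    """
    removed = []
    for sharding_file in sorted(checkpoint_root.glob("*/*/_sharding")):
        if sharding_file.is_file():
            sharding_file.unlink()
            removed.append(sharding_file)
    return removed


# Usage on a throwaway directory that mimics the assumed layout.
root = pathlib.Path(tempfile.mkdtemp()) / "checkpoints"
for step in range(2):
    item_dir = root / str(step) / "params"
    item_dir.mkdir(parents=True)
    (item_dir / "_sharding").write_text("{}")
    # Stand-in for the other files of a real checkpoint, which must survive.
    (item_dir / "other_file").write_text("{}")

removed = remove_sharding_files(root)
```

Only the files named `_sharding` are touched; everything else in the step directories is left in place.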

Hi @cpgaffney1! Are there any updates on this matter?

I found a PR with a fix for metadata reading (#671), but it hasn't been updated since January 17.

Thanks, Simon

Hi, apologies for the long delay on this. We concluded that using jax.Sharding directly in the metadata was a bad decision from the start, since it can't always be loaded correctly. We're adding a new representation of the sharding metadata that doesn't try to interact directly with real devices. You can track changes here: https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/sharding_metadata.py (the latest change doesn't have an external pull request yet). I expect this to be fixed this week or next.
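The idea of a device-agnostic representation can be illustrated with a small sketch: keep the sharding as plain data (here only the recorded device string, such as `cuda:0`) and map it to a real device only on request, returning `None` instead of raising when that device is absent. The class `DeviceAgnosticSharding` and its `resolve` method are hypothetical names for illustration; this is not the code in `sharding_metadata.py`.

```python
import dataclasses
from typing import Mapping, Optional, TypeVar

DeviceT = TypeVar("DeviceT")


@dataclasses.dataclass(frozen=True)
class DeviceAgnosticSharding:
    """Hypothetical metadata record describing a single-device sharding as data only."""

    device_str: str  # e.g. "cuda:0", as recorded when the checkpoint was saved

    def resolve(self, local_devices: Mapping[str, DeviceT]) -> Optional[DeviceT]:
        """Looks up the recorded device among the devices available now.

        Returns None when that device does not exist on this machine, so that
        reading metadata does not fail just because the hardware changed.
        """
        return local_devices.get(self.device_str)


# On a CPU-only machine, a GPU-recorded sharding simply resolves to None.
saved = DeviceAgnosticSharding(device_str="cuda:0")
cpu_only = {"cpu:0": "<cpu device>"}
resolved = saved.resolve(cpu_only)
```

With real JAX devices, the mapping could be built as `{str(d): d for d in jax.local_devices()}`; that mapping is an assumption of this sketch, mirroring the `device_map` lookup visible in the traceback.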

cc @liangyaning33 who is working on the implementation.

Hi, sorry about the delay. The issue is now fixed. Can you please try again? Thanks!

Hi, I ran into a similar issue.

I save my checkpoints together with metrics, train only on CPU, and then want to load a checkpoint on the same machine. Somehow, it looks for a cuda:0 device when reading the metadata. Any help would be greatly appreciated!
Error:
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().

Checkpoint Manager creation:

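from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions
from orbax.checkpoint.args import StandardRestore, StandardSave

options = CheckpointManagerOptions(
    best_fn=lambda metrics: metrics["metric1"],
    best_mode="min",
    max_to_keep=1,
    save_interval_steps=1,
)

checkpoint_manager = CheckpointManager(
    directory=checkpoint_dir,
    options=options,
)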

Checkpoint saving:

for step in tbar:
    train_batch = generate_batch(datamodule, "train")
    valid_batch = generate_batch(datamodule, "valid")

    state_neural_net, current_logs = step_fn(
        state_neural_net, train_batch, valid_batch
    )

    ckpt = state_neural_net
    checkpoint_manager.save(
        step,
        args=StandardSave(ckpt),
        metrics={
            "metric1": float(metric1),
            "metric2": float(metric2),
            "metric3": float(metric3),
        },
    )
# Block until the last asynchronous save has been written.
checkpoint_manager.wait_until_finished()

And then to load the checkpoint:

# Sets up the checkpoint manager as described above
out_class = cls(
    jobid=jobid,
    logger_path=logger_path,
    config=config,
    datamodule=datamodule,
)

if step is None:
    # Only checks steps with metrics available
    step = out_class.checkpoint_manager.best_step()
out_class.neural_net = out_class.checkpoint_manager.restore(
    step, args=StandardRestore()
)

But I get the following output and error:

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py:951, in CheckpointManager.restore(self, step, items, restore_kwargs, directory, args)
    948     args = typing.cast(args_lib.Composite, args)
    950 restore_directory = self._get_read_step_directory(step, directory)
--> 951 restored = self._checkpointer.restore(restore_directory, args=args)
    952 if self._single_item:
    953   return restored[DEFAULT_ITEM_NAME]

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/async_checkpointer.py:338, in AsyncCheckpointer.restore(self, directory, *args, **kwargs)
    336 """See superclass documentation."""
    337 self.wait_until_finished()
--> 338 return super().restore(directory, *args, **kwargs)

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py:168, in Checkpointer.restore(self, directory, *args, **kwargs)
    166 logging.info('Restoring item from %s.', directory)
    167 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 168 restored = self._handler.restore(directory, args=ckpt_args)
    169 logging.info('Finished restoring checkpoint from %s.', directory)
    170 return restored

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py:464, in CompositeCheckpointHandler.restore(self, directory, args)
    462     continue
    463   handler = self._get_or_set_handler(item_name, arg)
--> 464   restored[item_name] = handler.restore(
    465       self._get_item_directory(directory, item_name), args=arg
    466   )
    467 return CompositeResults(**restored)

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/standard_checkpoint_handler.py:166, in StandardCheckpointHandler.restore(self, directory, item, args)
    163   restore_args = checkpoint_utils.construct_restore_args(args.item)
    164 else:
    165   restore_args = checkpoint_utils.construct_restore_args(
--> 166       self.metadata(directory)
    167   )
    168 return super().restore(
    169     directory,
    170     args=pytree_checkpoint_handler.PyTreeRestoreArgs(
    171         item=args.item, restore_args=restore_args
    172     ),
    173 )

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1453, in PyTreeCheckpointHandler.metadata(self, directory)
   1427 """Returns tree metadata.
   1428 
   1429 The result will be a PyTree matching the structure of the saved checkpoint.
   (...)
   1450   tree containing metadata.
   1451 """
   1452 try:
-> 1453   return self._get_user_metadata(directory)
   1454 except FileNotFoundError as e:
   1455   raise FileNotFoundError('Could not locate metadata file.') from e

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1418, in PyTreeCheckpointHandler._get_user_metadata(self, directory)
   1415 async def _get_metadata():
   1416   return await asyncio.gather(*metadata_ops)
-> 1418 batched_metadatas = asyncio.run(_get_metadata())
   1419 for keypath_batch, metadata_batch in zip(
   1420     batched_keypaths.values(), batched_metadatas
   1421 ):
   1422   for keypath, value in zip(keypath_batch, metadata_batch):

File /.conda/envs/condreq/lib/python3.10/site-packages/nest_asyncio.py:30, in _patch_asyncio.<locals>.run(main, debug)
     28 task = asyncio.ensure_future(main)
     29 try:
---> 30     return loop.run_until_complete(task)
     31 finally:
     32     if not task.done():

File /.conda/envs/condreq/lib/python3.10/site-packages/nest_asyncio.py:98, in _patch_loop.<locals>.run_until_complete(self, future)
     95 if not f.done():
     96     raise RuntimeError(
     97         'Event loop stopped before Future completed.')
---> 98 return f.result()

File /.conda/envs/condreq/lib/python3.10/asyncio/futures.py:201, in Future.result(self)
    199 self.__log_traceback = False
    200 if self._exception is not None:
--> 201     raise self._exception.with_traceback(self._exception_tb)
    202 return self._result

File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:234, in Task.__step(***failed resolving arguments***)
    232         result = coro.send(None)
    233     else:
--> 234         result = coro.throw(exc)
    235 except StopIteration as exc:
    236     if self._must_cancel:
    237         # Task is cancelled right before coro stops.

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1416, in PyTreeCheckpointHandler._get_user_metadata.<locals>._get_metadata()
   1415 async def _get_metadata():
-> 1416   return await asyncio.gather(*metadata_ops)

File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:304, in Task.__wakeup(self, future)
    302 def __wakeup(self, future):
    303     try:
--> 304         future.result()
    305     except BaseException as exc:
    306         # This may also be a cancellation.
    307         self.__step(exc)

File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:232, in Task.__step(***failed resolving arguments***)
    228 try:
    229     if exc is None:
    230         # We use the `send` method directly, because coroutines
    231         # don't have `__iter__` and `__next__` methods.
--> 232         result = coro.send(None)
    233     else:
    234         result = coro.throw(exc)

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1403, in ArrayHandler.metadata(self, infos)
   1401     shardings.append(None)
   1402     continue
-> 1403   deserialized = _deserialize_sharding_from_json_string(
   1404       sharding_string.item()
   1405   )
   1406   shardings.append(deserialized or None)
   1407 else:

File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:166, in _deserialize_sharding_from_json_string(sharding_string)
    161   if device := _deserialize_sharding_from_json_string.device_map.get(
    162       device_str, None
    163   ):
    164     return SingleDeviceSharding(device)
--> 166   raise ValueError(
    167       f'{ShardingTypes.SINGLE_DEVICE_SHARDING.value} with'
    168       f' Device={device_str} was not found in jax.local_devices().'
    169   )
    171 else:
    172   raise NotImplementedError(
    173       'Sharding types other than `jax.sharding.NamedSharding` have not been '
    174       'implemented.'
    175   )

ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().

Versions:
flax==0.7.4
jax==0.4.20
jaxlib==0.4.20+cuda12.cudnn89
optax==0.1.9
orbax-checkpoint==0.5.7

UPDATE (solved): For me, the problem went away after downgrading nvidia-cudnn-cu12 from 9.1.0.70 to the cuDNN 8.9 release that matches jaxlib-0.4.20+cuda12.cudnn89: pip install nvidia-cudnn-cu12==8.9.7.29.
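A small sketch for spotting this kind of mismatch before it surfaces as a confusing error. It assumes that the `cudnnXY` tag in a jaxlib local version (for example `cudnn89`) means cuDNN major version `X` and minor version `Y`, and compares that against the version of the installed `nvidia-cudnn-cu12` wheel. The function name `cudnn_matches_jaxlib` is hypothetical.

```python
import re
from importlib import metadata


def cudnn_matches_jaxlib(jaxlib_version: str, cudnn_version: str) -> bool:
    """Returns True if the cuDNN wheel's major.minor matches jaxlib's build tag.

    Assumes a tag like "cudnn89" encodes major=8, minor=9 (first digit is the
    major version, the remaining digits the minor version).
    """
    match = re.search(r"cudnn(\d)(\d+)", jaxlib_version)
    if match is None:
        raise ValueError(f"no cudnn tag in jaxlib version {jaxlib_version!r}")
    expected = (int(match.group(1)), int(match.group(2)))
    major, minor = (int(part) for part in cudnn_version.split(".")[:2])
    return (major, minor) == expected


if __name__ == "__main__":
    # Compare the packages actually installed in this environment, if any.
    try:
        jaxlib_v = metadata.version("jaxlib")
        cudnn_v = metadata.version("nvidia-cudnn-cu12")
        print(jaxlib_v, cudnn_v, cudnn_matches_jaxlib(jaxlib_v, cudnn_v))
    except (metadata.PackageNotFoundError, ValueError) as err:
        print(f"could not compare versions: {err}")
```

With the versions from this report, `cudnn_matches_jaxlib("0.4.20+cuda12.cudnn89", "9.1.0.70")` is False and the downgraded `"8.9.7.29"` gives True.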

Also note: whenever possible, specify the shardings for your tree in args=StandardRestore(). Alternatively, specify restore_type as np.ndarray. See https://orbax.readthedocs.io/en/latest/checkpointing_pytrees.html