google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

Parse structure of a saved PyTree checkpoint

minotru opened this issue

Hi,

Is there a way to parse the structure of a saved PyTree checkpoint?
I found AbstractCheckpointer.structure, but it is deprecated.


CONTEXT:

I have a checkpoint saved with orbax's PyTreeCheckpointHandler that contains sharded jax.Array objects. I am trying to load it on a machine with a single CPU device, and orbax fails because the saved sharding requires 8 devices.

Here is where it fails:
>>> checkpoint_manager.restore(checkpoint_manager.latest_step(), items={"state" : None})
...
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_manager.py", line 472, in restore
    restored_items = self._restore_impl(
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_manager.py", line 504, in _restore_impl
    restored[item_name] = self._checkpointers[item_name].restore(
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpointer.py", line 99, in restore
    restored = self._handler.restore(directory, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1065, in restore
    restored_item = asyncio.run(
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 890, in _maybe_deserialize
    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py", line 1260, in deserialize
    _deserialize_sharding_from_json_string(serialized_string.item())
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py", line 135, in _deserialize_sharding_from_json_string
    np.array(jax.devices()).reshape(shape), axis_names=axis_names
ValueError: cannot reshape array of size 1 into shape (4,2)
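
The final ValueError comes down to a device-count mismatch: a (4, 2) mesh needs 4 * 2 = 8 devices, but only one is visible. A minimal sketch of that check in plain Python, where devices_match_mesh is a hypothetical helper written for illustration and not part of orbax:

```python
import math


def devices_match_mesh(num_devices: int, mesh_shape: tuple[int, ...]) -> bool:
    """Return True if num_devices can be arranged into mesh_shape.

    Hypothetical helper: it mirrors the condition under which
    reshaping the device list into the mesh shape would succeed.
    """
    return num_devices == math.prod(mesh_shape)


# Mesh shape taken from the error message above.
saved_mesh_shape = (4, 2)

print(devices_match_mesh(8, saved_mesh_shape))  # True: 8 devices fill a 4x2 mesh
print(devices_match_mesh(1, saved_mesh_shape))  # False: a single CPU device does not
```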

To load a checkpoint with a different sharding, I need to pass restore_args: a tree of ArrayRestoreArgs with the same structure as the saved checkpoint.
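
For illustration, if the structure were known, building such a tree would be a per-leaf map. The sketch below uses only the standard library: StandInRestoreArgs is a hypothetical placeholder for ArrayRestoreArgs, map_leaves is a hypothetical helper, and the structure dict is made up, not taken from a real checkpoint.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StandInRestoreArgs:
    """Hypothetical placeholder for orbax's ArrayRestoreArgs."""
    sharding: str


def map_leaves(fn, tree):
    """Apply fn to every leaf of a nested dict, preserving its structure."""
    if isinstance(tree, dict):
        return {key: map_leaves(fn, value) for key, value in tree.items()}
    return fn(tree)


# Made-up checkpoint structure; leaf values are ignored, only the shape matters.
structure = {"params": {"dense": {"kernel": 0, "bias": 0}}, "step": 0}

# Every leaf gets the same target sharding (a label here, an assumption
# standing in for a real single-device sharding object).
restore_args = map_leaves(
    lambda _: StandInRestoreArgs(sharding="single_cpu_device"), structure
)
print(restore_args["params"]["dense"]["kernel"])
```

With JAX installed, jax.tree_util.tree_map fills the role of map_leaves for general pytrees.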

The problem is that I do not know the structure of the saved checkpoint, so I cannot create restore_args with the proper structure.

Digging into orbax's source code showed that PyTreeCheckpointHandler uses _get_internal_metadata to get the item structure, but it is a private method.

So:

  1. What is the right way to load a checkpoint with a different sharding without knowing the checkpoint structure?
  2. Is there a public method to parse the checkpoint structure?

Thanks!

Sorry, my issue seems to be a duplicate of #648 and #678, so I will close it in favour of those.