google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

Array has been deleted

fding opened this issue.

Hi, we are trying out the Orbax (0.4.1) AsyncCheckpointer (used through CheckpointManager), and we are getting "Array has been deleted" errors. It seems as if the async checkpointer is trying to copy a jax.Array from device to host memory after that array is no longer available. The Orbax documentation says that "From start to finish, async checkpointing for a train state of arrays works by first performing a blocking copy of the arrays from device to host", but I wonder if there are any gotchas in how we should use Orbax checkpointing.
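The documented behaviour quoted above (a blocking device-to-host copy, then an asynchronous write) can be illustrated with a minimal sketch. The helper `async_save_snapshot` is hypothetical and is not Orbax's implementation, and the in-memory `np.save` stands in for the real file write. What it shows is the ordering: once the blocking copy has returned, the training loop may free the device buffer without breaking the background write.

```python
import concurrent.futures
import io

import jax
import jax.numpy as jnp
import numpy as np


def async_save_snapshot(executor, tree):
    """Hypothetical helper: blocking host copy, then a background "write".

    Only the host copy is handed to the executor, so the background task
    never touches device buffers.
    """
    # Blocking device-to-host copy; np.array returns a new host array per leaf.
    host_tree = jax.tree_util.tree_map(np.array, tree)

    def write():
        # Stand-in for the real file write: serialize each leaf into memory.
        out = []
        for leaf in jax.tree_util.tree_leaves(host_tree):
            buf = io.BytesIO()
            np.save(buf, leaf)
            out.append(buf.getvalue())
        return out

    return executor.submit(write)


params = {"w": jnp.ones(256, dtype=jnp.float32)}
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
    future = async_save_snapshot(pool, params)
    # The training loop may now delete or donate the device buffer...
    params["w"].delete()
    # ...and the background write still succeeds because it only reads the host copy.
    written = future.result()
```

If the copy were instead deferred into `write()`, the same sequence would read a deleted array.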

Here is the stack trace:

Exception in thread Thread-314 (_finalize):
Traceback (most recent call last):
File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
     self.run()
File "/usr/lib/python3.11/threading.py", line 982, in run
     self._target(*self._args, **self._kwargs)
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py", line 956, in _finalize
    self.wait_until_finished(join_finalize_thread=False)
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py", line 888, in wait_until_finished
     checkpointer.wait_until_finished()  # pytype: disable=attribute-error
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 262, in wait_until_finished
     self._async_manager.wait_until_finished()
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 154, in wait_until_finished
     self.check_for_errors()
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 145, in check_for_errors
     raise exception  # pylint: disable=raising-bad-type
     ^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 97, in _thread_func
     future.result()
File "/usr/lib/python3.11/concurrent/futures/_base.py", line 456, in result
     return self.__get_result()
            ^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
     raise self._exception
File "/usr/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/aggregate_handlers.py", line 75, in _serialize_fn
     msgpack = msgpack_utils.msgpack_serialize(serializable_dict)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 216, in msgpack_serialize
     return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/msgpack/__init__.py", line 36, in packb
     return Packer(**kwargs).pack(o)
File "msgpack/_packer.pyx", line 285, in msgpack._cmsgpack.Packer._pack
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 78, in _msgpack_ext_pack
     return msgpack.ExtType(_MsgpackExtType.NDARRAY, _ndarray_to_bytes(x))
                                                     ^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 40, in _ndarray_to_bytes
     arr = np.array(arr)
           ^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 377, in __array__
     return np.asarray(self._value, dtype=dtype)
                       ^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 340, in wrapper
     return func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 562, in _value
     self._check_if_deleted()
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 530, in _check_if_deleted
     raise RuntimeError(
RuntimeError: Array has been deleted with shape=float32[256].
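The final frame means the jax.Array's device buffer had already been freed when `msgpack_utils._ndarray_to_bytes` called `np.array` on it. In JAX a buffer is freed by an explicit `delete()` or when the array is passed as a donated argument (`donate_argnums` in `jax.jit`); whether donation is what happens in this training loop is an assumption. A hypothetical debugging helper like the one below can be run on the pytree right before the save call, or after the next training step, to see which leaves are affected.

```python
import jax
import jax.numpy as jnp


def find_deleted_leaves(tree):
    """Hypothetical helper: return the pytree paths of deleted jax.Array leaves."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        # Non-JAX leaves (ints, NumPy arrays) cannot be deleted, so skip them.
        if isinstance(leaf, jax.Array) and leaf.is_deleted()
    ]


state = {"params": {"bias": jnp.zeros(256, dtype=jnp.float32)}, "step": 3}
# Simulate what buffer donation does to an input array.
state["params"]["bias"].delete()
deleted = find_deleted_leaves(state)
```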

Hi David, can you please try 0.4.7? It might help.

Hi Niket,

I tried upgrading to 0.4.7, and I'm still hitting the error. The only difference is that the error now crashes the training loop, whereas previously training would continue. If I set wait=True, the error disappears, but obviously that's not performant.

Do I need to pass in np.array(jax.experimental.multihost_utils.global_array_to_host_local_array(params)) explicitly? And do I need to be careful about inserting jax.block_until_ready calls?

I did a bit more debugging, and I think the error occurs when aggregate=True in the save_args. Reading the code, it seems that in the aggregate=False branch, type_handlers.ArrayHandler.async_save correctly awaits the copy futures, but when aggregate=True, the threads that write the checkpoint access the DeviceArrays directly (https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/aggregate_handlers.py#L72)?
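The suspected race can be reproduced deterministically in a small sketch. `serialize_later` is a hypothetical stand-in for a writer thread that reads device arrays lazily, as described for the aggregate=True path; it is not Orbax code. A `threading.Event` forces the unlucky ordering: the training step frees the buffer before the writer has copied it.

```python
import threading

import jax.numpy as jnp
import numpy as np


def serialize_later(tree, start, errors):
    """Hypothetical writer that only copies device arrays to host when it runs."""
    start.wait()  # simulate the writer thread being scheduled late
    try:
        for leaf in tree.values():
            np.array(leaf)  # the device-to-host copy happens only now
    except RuntimeError as err:
        errors.append(str(err))


params = {"bias": jnp.zeros(256, dtype=jnp.float32)}
start, errors = threading.Event(), []
writer = threading.Thread(target=serialize_later, args=(params, start, errors))
writer.start()

# The training step frees the old buffer (for example through donation)
# before the writer has copied it.
params["bias"].delete()
start.set()
writer.join()
```

Moving the `np.array` copy onto the calling thread before the writer starts, which is what awaiting the copy futures achieves in the aggregate=False branch according to the reading above, avoids the error in this sketch.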

Note that aggregate=True is automatically set by flax.orbax_utils.save_args_from_target(ckpt) for non-sharded params (e.g. all params in a smaller model that is not model sharded or zero sharded). Do you recommend setting aggregate=True? I guess even with a non-sharded model, parameters can be big, so perhaps checking if the array is sharded or not isn't the most accurate heuristic?

Thanks for debugging the issue and associating it with SaveArgs.aggregate option!

While we recreate the issue in our dev setup, please switch to aggregate=False if that works for your use case. SaveArgs.aggregate is mainly meant for performance optimization. I hope that unblocks you.

To recreate the issue, I will need your help:

> Do I need to pass in np.array(jax.experimental.multihost_utils.global_array_to_host_local_array(params)) explicitly?
Which API do you want to pass this to? As input to save(...)?

> And do I need to be careful about inserting jax.block_until_ready calls?
Which code snippet were you referring to?

> Note that aggregate=True is automatically set by flax.orbax_utils.save_args_from_target(ckpt) for non-sharded params (e.g. all params in a smaller model that is not model sharded or zero sharded).
Please note that aggregate=True is set if (1) the array is not a jax.Array, or (2) it is a jax.Array and is_fully_replicated is True. In other words, the condition is about replication.
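The rule above can be restated per leaf in a short sketch. `default_aggregate` is a hypothetical name, not the body of Flax's `save_args_from_target`, and it returns plain booleans instead of Orbax `SaveArgs` objects. It also shows why a model that is not sharded ends up with aggregate=True everywhere: an array placed on a single device, or replicated across all devices, reports `is_fully_replicated` as True.

```python
import jax
import jax.numpy as jnp
import numpy as np


def default_aggregate(leaf):
    """Hypothetical restatement of the default aggregate rule, per leaf."""
    if not isinstance(leaf, jax.Array):
        return True  # e.g. NumPy arrays and Python scalars
    # Only arrays split across devices (not fully replicated) opt out.
    return leaf.is_fully_replicated


ckpt = {
    "step": 10,
    "w": jnp.ones((4, 4)),  # single-device array, so fully replicated
    "host_stats": np.zeros(3),
}
aggregate_flags = jax.tree_util.tree_map(default_aggregate, ckpt)
```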

Can you please share a code snippet or a pointer that shows your Orbax usage, so that I can recreate and resolve the issue?