Array has been deleted
fding opened this issue
Hi, we are trying out the Orbax (0.4.1) AsyncCheckpointer (used through CheckpointManager). We are getting "Array has been deleted" errors. It seems as if the async checkpointer is trying to copy a jax.Array from device to host memory, but that array is no longer available. The Orbax documentation says that "From start to finish, async checkpointing for a train state of arrays works by first performing a blocking copy of the arrays from device to host", but I wonder if there are any gotchas in how we should use Orbax checkpointing.
Here is the stack trace:
Exception in thread Thread-314 (_finalize):
Traceback (most recent call last):
File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
self.run()
File "/usr/lib/python3.11/threading.py", line 982, in run
self._target(*self._args, **self._kwargs)
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py", line 956, in _finalize
self.wait_until_finished(join_finalize_thread=False)
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py", line 888, in wait_until_finished
checkpointer.wait_until_finished() # pytype: disable=attribute-error
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 262, in wait_until_finished
self._async_manager.wait_until_finished()
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 154, in wait_until_finished
self.check_for_errors()
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 145, in check_for_errors
raise exception # pylint: disable=raising-bad-type
^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 97, in _thread_func
future.result()
File "/usr/lib/python3.11/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/usr/lib/python3.11/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/aggregate_handlers.py", line 75, in _serialize_fn
msgpack = msgpack_utils.msgpack_serialize(serializable_dict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 216, in msgpack_serialize
return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/msgpack/__init__.py", line 36, in packb
return Packer(**kwargs).pack(o)
File "msgpack/_packer.pyx", line 285, in msgpack._cmsgpack.Packer._pack
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 78, in _msgpack_ext_pack
return msgpack.ExtType(_MsgpackExtType.NDARRAY, _ndarray_to_bytes(x))
^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 40, in _ndarray_to_bytes
arr = np.array(arr)
^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 377, in __array__
return np.asarray(self._value, dtype=dtype)
^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 340, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 562, in _value
self._check_if_deleted()
File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 530, in _check_if_deleted
raise RuntimeError(
RuntimeError: Array has been deleted with shape=float32[256].
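The last frames of the trace can be reproduced in isolation. The following is a minimal sketch, not the code path Orbax takes: it uses an explicit `delete()` call as a stand-in for whatever released the buffer in the real training loop (the thread does not establish the actual cause), and then converts the array with `np.array`, as `_ndarray_to_bytes` does in the trace.

```python
import jax.numpy as jnp
import numpy as np

# A small array with the same shape and dtype as the one in the error message.
x = jnp.ones(256, dtype=jnp.float32)

# Assumption for illustration: release the buffer explicitly. In the reported
# issue the buffer was released by some other part of the program.
x.delete()
assert x.is_deleted()

# Converting a deleted jax.Array to NumPy goes through __array__ and raises.
try:
    np.array(x)
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

Any thread that still holds a reference to `x` and tries to read it after this point would fail in the same way, which is consistent with the error surfacing in the background serialization thread.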
Hi David, can you please try 0.4.7? It might help.
Hi Niket,
I tried upgrading to 0.4.7, and I'm still hitting the error. The only difference is that the error now crashes the training loop, whereas previously training would continue. If I do wait=True, the error disappears, but obviously that's not performant.
Do I need to pass in np.array(jax.experimental.multihost_utils.global_array_to_host_local_array(params)) explicitly? And do I need to be careful about inserting jax.block_until_ready calls?
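For reference, an explicit host copy of the kind asked about could look like the sketch below. It is only an illustration under stated assumptions: `snapshot_to_host` is a hypothetical helper (not an Orbax or JAX API), the arrays are assumed to be fully addressable from one process and small enough for host memory, and nothing in this thread confirms that such a copy is required. Multi-host global arrays would need a conversion step first, such as the `multihost_utils` call mentioned in the question.

```python
import jax
import jax.numpy as jnp
import numpy as np


def snapshot_to_host(tree):
    """Hypothetical helper: copy every array leaf into an independent NumPy array."""
    # np.array(...) copies by default, so the result does not alias device memory.
    return jax.tree_util.tree_map(lambda leaf: np.array(leaf), tree)


params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}
host_params = snapshot_to_host(params)

# Simulate the device buffers being released after the snapshot was taken.
for leaf in jax.tree_util.tree_leaves(params):
    leaf.delete()

# The host copies remain readable even though the device arrays are gone.
total = float(host_params["w"].sum())
```

The trade-off is that the copy blocks the training loop and doubles host memory for the checkpointed state, so it is better suited to narrowing down the bug than to production use.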
I did a bit more debugging. I think the error occurs when aggregate=True in the save_args. Reading the code, it seems like in the aggregate=False branch, the type_handlers.ArrayHandler.async_save correctly awaits the copy futures, but when aggregate=True, the threads that write the checkpoints access the DeviceArrays directly (https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/aggregate_handlers.py#L72)?
Note that aggregate=True is automatically set by flax.orbax_utils.save_args_from_target(ckpt) for non-sharded params (e.g. all params in a smaller model that is not model sharded or zero sharded). Do you recommend setting aggregate=True? I guess even with a non-sharded model, parameters can be big, so perhaps checking if the array is sharded or not isn't the most accurate heuristic?
Thanks for debugging the issue and associating it with SaveArgs.aggregate option!
While we recreate the issue in our dev setup, please switch to aggregate=False if that works for your use case. SaveArgs.aggregate is mainly meant for performance optimization. I hope that unblocks you.
To recreate the issue, I will need your help:
Do I need to pass in np.array(jax.experimental.multihost_utils.global_array_to_host_local_array(params)) explicitly?
Which API do you want to pass this to? As input to save(...)?
And do I need to be careful about inserting jax.block_until_ready calls?
Which code snippet were you referring to?
Note that aggregate=True is automatically set by flax.orbax_utils.save_args_from_target(ckpt) for non-sharded params (e.g. all params in a smaller model that is not model sharded or zero sharded).
Please note that aggregate=True is set if 1) the array is not a jax.Array, or 2) it is a jax.Array and is_fully_replicated. In other words, the condition is about replication.
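The rule described above can be sketched as a small predicate. This is an illustrative restatement of that description, not the Flax implementation: `would_aggregate` is a hypothetical name, and it returns a plain bool rather than building SaveArgs objects.

```python
import jax
import jax.numpy as jnp
import numpy as np


def would_aggregate(leaf) -> bool:
    """Hypothetical helper mirroring the two conditions described above."""
    # 1) Anything that is not a jax.Array (e.g. NumPy values) is aggregated.
    if not isinstance(leaf, jax.Array):
        return True
    # 2) A jax.Array is aggregated only if it is fully replicated.
    return bool(leaf.is_fully_replicated)


tree = {"step": np.int32(0), "w": jnp.ones((8,))}
flags = jax.tree_util.tree_map(would_aggregate, tree)
print(flags)
```

An array living on a single device counts as fully replicated, so every parameter of a model that is not partitioned across devices would take the aggregate=True path under this rule, regardless of its size.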
Can you please share a code snippet or pointer which details the Orbax usage, so that I can recreate the issue and resolve it?