Incorrect null check in pytree_checkpoint_handler.py
hr0nix opened this issue
orbax/checkpoint/pytree_checkpoint_handler.py:661 has the following check: `if not item`. It most likely should be `if item is None`. Otherwise the check raises an error when item is an array (which is a valid pytree according to the pytree definition), because the truth value of an array with more than one element is ambiguous.
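The failure mode can be reproduced without Orbax. The sketch below uses plain NumPy. The helper names is_missing_truthy and is_missing_none are hypothetical: they stand in for the two forms of the check and are not the actual code at line 661.

```python
import numpy as np


def is_missing_truthy(item):
    # Mirrors the reported check, `if not item`: relies on Python truthiness,
    # which calls the item's __bool__.
    return not item


def is_missing_none(item):
    # The suggested check, `if item is None`: an identity test that never
    # calls the item's __bool__.
    return item is None


arr = np.arange(4)  # a multi-element array, i.e. a single pytree leaf

try:
    is_missing_truthy(arr)
except ValueError as err:
    # NumPy refuses to convert a multi-element array to bool.
    print("`not item` raised ValueError:", err)

print(is_missing_none(arr))   # False
print(is_missing_none(None))  # True

# The two checks also disagree on empty containers.
print(is_missing_truthy({}), is_missing_none({}))  # True False
```

Note that `not item` does not raise for a one-element array: `not np.array([0])` is simply True, so such an item would be silently treated as missing rather than producing an error.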
Please use `ArrayCheckpointHandler` instead of `PyTreeCheckpointHandler` to handle single arrays.
I think the distinction being made is between PyTree containers and leaves. From the JAX docs: "By default, pytree containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves". To save a single `jax.Array` or `np.ndarray`, `ArrayCheckpointHandler` is provided instead. Obviously there is some grey area, but I'm hesitant to cram yet more functionality into `PyTreeCheckpointHandler`: it has become bloated enough as it is, and we're pushing for simplification in a few key aspects.
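To make the container/leaf distinction concrete, here is a simplified stand-in for jax.tree_util.tree_leaves that treats only the default container types from the quoted docs as internal nodes. It is an illustration under that assumption, not JAX's implementation; custom registered node types, for example, are omitted.

```python
from collections import OrderedDict

import numpy as np


def tree_leaves(tree):
    """Simplified stand-in for jax.tree_util.tree_leaves.

    Only the default container types quoted from the JAX docs are
    treated as internal nodes; everything else is a leaf.
    """
    if tree is None:
        # None is a container with no children, so it contributes no leaves.
        return []
    if isinstance(tree, (list, tuple)):  # also covers namedtuple
        return [leaf for child in tree for leaf in tree_leaves(child)]
    if isinstance(tree, dict):
        # OrderedDict keeps insertion order; plain dicts are flattened by sorted key.
        keys = list(tree) if isinstance(tree, OrderedDict) else sorted(tree)
        return [leaf for k in keys for leaf in tree_leaves(tree[k])]
    # Numbers, np.ndarray, jax.Array, ...: treated as leaves.
    return [tree]


arr = np.arange(4)
print(len(tree_leaves(arr)))  # 1: a bare array is a single leaf
print(len(tree_leaves({"w": arr, "b": None, "opt": [arr, 3]})))  # 3
```

A bare array is therefore a valid pytree, but a degenerate one: it is a single leaf with no container around it, which is the case the maintainers route to ArrayCheckpointHandler.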
Fair enough. In any case, the resulting error was really obscure, and the error message for this case could be improved.
Sure, that's also fair. I'll make a TODO.
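One possible shape for a clearer error, sketched as a hypothetical validation helper. The name check_save_item and its messages are invented for illustration and are not part of Orbax. The helper checks for None by identity and detects a bare array before any truthiness test can fail. Only np.ndarray is checked, to keep the sketch free of a JAX dependency; a real check would presumably also cover jax.Array.

```python
import numpy as np


def check_save_item(item):
    """Hypothetical pre-save validation; not Orbax's actual code."""
    # Identity check: never triggers an array's ambiguous __bool__.
    if item is None:
        raise ValueError("Must provide an item to save.")
    # A bare array is a pytree leaf, not a container. Point the user to the
    # dedicated handler instead of failing with an ambiguous-truth-value error.
    # (Assumption: a real implementation would also check for jax.Array.)
    if isinstance(item, np.ndarray):
        raise TypeError(
            "PyTreeCheckpointHandler expects a PyTree container (e.g. a dict "
            "of arrays), but got a single array. Use ArrayCheckpointHandler "
            "to save a single jax.Array or np.ndarray."
        )


try:
    check_save_item(np.zeros(3))
except TypeError as err:
    print(err)

check_save_item({"params": np.zeros(3)})  # a container passes silently
```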