google / orbax

Orbax provides common utility libraries for JAX users.

Home page: https://orbax.readthedocs.io/



Incorrect null check in pytree_checkpoint_handler.py

hr0nix opened this issue.

Line 661 of orbax/checkpoint/pytree_checkpoint_handler.py has the following check: if not item

It most likely should be "if item is None". Otherwise the check raises an error when item is an array (which is a valid pytree according to the pytree definition), because the truth value of a multi-element array is ambiguous.
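The failure mode can be reproduced with plain NumPy. The two functions below are hypothetical stand-ins for the quoted guard and the suggested fix, not Orbax code; they only show how Python's truthiness check behaves on a multi-element array compared with an explicit None check.

```python
import numpy as np


def guard_with_truthiness(item):
    # Hypothetical stand-in for the "if not item" guard quoted above.
    if not item:
        return "treated as missing"
    return "accepted"


def guard_with_is_none(item):
    # Hypothetical stand-in for the suggested "if item is None" guard.
    if item is None:
        return "treated as missing"
    return "accepted"


arr = np.arange(4)

try:
    guard_with_truthiness(arr)
except ValueError as err:
    # NumPy refuses to reduce a multi-element array to a single bool.
    print("ValueError:", err)

print(guard_with_is_none(arr))  # accepted
print(guard_with_truthiness({}))  # treated as missing
print(guard_with_is_none({}))  # accepted
```

The last two lines show a second difference: "not item" also treats an empty container such as {} as missing, while "item is None" accepts it.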

Please use ArrayCheckpointHandler instead of PyTreeCheckpointHandler to handle single arrays.

I think the relevant distinction is between pytree containers and leaves. From the JAX docs: "By default, pytree containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves". ArrayCheckpointHandler is provided for saving a single jax.Array or np.ndarray. There is admittedly some grey area, but I'm hesitant to cram yet more functionality into PyTreeCheckpointHandler: it has become bloated enough as it is, and we're pushing to simplify it in a few key areas.
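A minimal sketch of the container/leaf distinction described above, assuming only the default container types from the quoted JAX docs (custom registered pytree nodes are ignored). The helper suggest_handler is hypothetical; it returns the name of the Orbax handler class rather than constructing it, so no Orbax API calls are implied.

```python
from collections import OrderedDict, namedtuple

import numpy as np

# Default pytree containers per the quoted JAX docs. namedtuple is a tuple
# subclass and OrderedDict is a dict subclass, so isinstance covers both.
DEFAULT_CONTAINERS = (list, tuple, dict)


def suggest_handler(item):
    """Hypothetical helper: pick a handler name from the item's top level."""
    if item is None:
        # None is technically an empty pytree, but here it means "no item",
        # matching the "item is None" check discussed in the issue.
        raise ValueError("Nothing to save: item is None.")
    if isinstance(item, DEFAULT_CONTAINERS):
        return "PyTreeCheckpointHandler"
    # Duck-typed leaf check: np.ndarray and jax.Array both expose shape and dtype.
    if hasattr(item, "shape") and hasattr(item, "dtype"):
        return "ArrayCheckpointHandler"
    raise TypeError(f"Unsupported item type: {type(item).__name__}")


Point = namedtuple("Point", ["x", "y"])
print(suggest_handler({"w": np.ones(3)}))  # PyTreeCheckpointHandler
print(suggest_handler(OrderedDict(b=np.zeros(2))))  # PyTreeCheckpointHandler
print(suggest_handler(Point(np.ones(1), np.ones(1))))  # PyTreeCheckpointHandler
print(suggest_handler(np.arange(4)))  # ArrayCheckpointHandler
```

Only the top level is inspected: a dict whose values are arrays is still a container, while a bare array is a leaf and falls to the array handler.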

Fair enough. In any case, the error was really obscure, and the error message for this case could be improved.
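One way the error could be made less obscure, sketched under the assumption that a guard sits at the start of the save path: check for None explicitly, then reject a bare array with a message that names ArrayCheckpointHandler. The function check_pytree_item and the message wording are hypothetical, not Orbax's actual code.

```python
import numpy as np


def check_pytree_item(item):
    """Hypothetical validation for an item passed to a PyTree-style handler."""
    if item is None:
        raise ValueError("Expected a PyTree to save, but got None.")
    # Duck-typed array check: np.ndarray and jax.Array both have shape and dtype.
    if hasattr(item, "shape") and hasattr(item, "dtype"):
        raise TypeError(
            f"Got a bare array of shape {tuple(item.shape)}; "
            "PyTreeCheckpointHandler expects a pytree container such as a "
            "dict or list. Use ArrayCheckpointHandler to save a single array."
        )
    return item


try:
    check_pytree_item(np.zeros((2, 3)))
except TypeError as err:
    print(err)
```

Checking for None with "is" first keeps the guard from ever evaluating the truth value of an array, so the only error a user sees is the explicit one.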

Sure, that's also fair. I'll make a TODO.