google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

New interface does not support custom empty pytree class inherited from dict

ZaberKo opened this issue.

Reproduction code:

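import jax
import orbax.checkpoint as ocp

class PyTreeDict(dict):
    pass

jax.tree_util.register_pytree_node(
    PyTreeDict,
    lambda d: (tuple(d.values()), tuple(d.keys())),  # flatten -> (children, aux_data)
    lambda keys, values: PyTreeDict(dict(zip(keys, values)))  # unflatten(aux_data, children)
)

a = {"a": PyTreeDict()}  # ValueError: Expected dict, got {}.
# a = PyTreeDict()  # ValueError: Found empty item

# `ckpt` is not defined in the original report; a Checkpointer wrapping
# StandardCheckpointHandler is assumed here.
ckpt = ocp.Checkpointer(ocp.StandardCheckpointHandler())

path = ocp.test_utils.erase_and_create_empty('./debug').resolve() / 'ckpt'
ckpt.save(path, a)
ckpt.restore(path, args=ocp.args.StandardRestore(a))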
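A minimal JAX-only sketch (it does not call Orbax) of what the checkpointer is handed in the first case: the empty `PyTreeDict` flattens to zero leaves, yet it still appears in the tree structure and survives an unflatten round trip. The names `empty_tree`, `leaves`, `treedef` and `roundtrip` are illustrative only. The idea that this leafless custom node is what trips Orbax's handling is an assumption, consistent with the maintainer's mention of "empty node issues" but not confirmed by the report.

```python
import jax


class PyTreeDict(dict):
    """dict subclass registered as a custom pytree node, as in the report."""


jax.tree_util.register_pytree_node(
    PyTreeDict,
    # flatten: children are the values, auxiliary data is the tuple of keys
    lambda d: (tuple(d.values()), tuple(d.keys())),
    # unflatten: rebuild from (aux_data=keys, children=values)
    lambda keys, values: PyTreeDict(zip(keys, values)),
)

empty_tree = {"a": PyTreeDict()}
leaves, treedef = jax.tree_util.tree_flatten(empty_tree)

# The empty custom node contributes no leaves at all, but it is still
# recorded in the tree definition that a restore target must reproduce.
print(leaves)
print(treedef.num_leaves)

# Unflattening with zero leaves rebuilds the empty PyTreeDict node.
roundtrip = jax.tree_util.tree_unflatten(treedef, leaves)
print(type(roundtrip["a"]).__name__, roundtrip["a"])
```

So from JAX's point of view the structure is perfectly valid; any error has to come from how the checkpointing layer treats a node with no leaves.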

This issue is related to #720 and a066d9c.
@niketkumar

Thanks for reporting; we're looking into some refactoring that will resolve these empty-node issues.