google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

Saving doesn't work and results in extra *.npy extension?

allen-adastra opened this issue

Hello!
I'm trying to make orbax play nice with equinox's serialise/deserialise methods but am running into mysterious issues. Essentially, after saving, I get this in my save directory:

awang@mfews-awang:~ $ ls /tmp/tmpsre4jda3/10/default
model.eqx.npy
awang@mfews-awang:~ $ cat /tmp/tmpsre4jda3/10/default/model.eqx.npy 
�NUMPYv{'descr': '<f4', 'fortran_order': False, 'shape': (2,), }                                                            
���>����awang@mfews-awang:~ $ 

My checkpoint handler:

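from dataclasses import dataclass

import equinox as eqx
import orbax.checkpoint as ocp
from etils import epath


class EquinoxCheckpointHandler(ocp.CheckpointHandler):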
    def save(
        self,
        directory: epath.Path,
        args: "EquinoxStateSave",
    ):
        full_path = directory / "model.eqx"
        eqx.tree_serialise_leaves(full_path, args.item, is_leaf=eqx.is_array_like)

    def restore(
        self,
        directory: epath.Path,
        args: "EquinoxStateRestore",
    ) -> eqx.Module:
        loaded = eqx.tree_deserialise_leaves(
            directory / "model.eqx", args.item, is_leaf=eqx.is_array_like
        )
        return loaded


@ocp.args.register_with_handler(EquinoxCheckpointHandler, for_save=True)
@dataclass
class EquinoxStateSave(ocp.args.CheckpointArgs):
    item: eqx.Module


@ocp.args.register_with_handler(EquinoxCheckpointHandler, for_restore=True)
@dataclass
class EquinoxStateRestore(ocp.args.CheckpointArgs):
    item: eqx.Module

and here is my code:

import tempfile

import jax


def test_checkpoint_handler():
    checkpoint_dir = tempfile.mkdtemp()

    def build_manager_and_net(key):
        options = ocp.CheckpointManagerOptions(enable_async_checkpointing=False)

        manager = ocp.CheckpointManager(
            directory=checkpoint_dir,
            options=options,
            checkpointers=ocp.Checkpointer(EquinoxCheckpointHandler()),
        )

        net = build_random_nn(key)
        return manager, net

    manager0, net0 = build_manager_and_net(jax.random.PRNGKey(42))

    manager0.save(10, args=EquinoxStateSave(net0), metrics={"loss": 0.1})
    manager0.wait_until_finished()

    manager1, net1 = build_manager_and_net(jax.random.PRNGKey(420))

    net1_restore = manager0.restore(10, args=EquinoxStateRestore(net1))

Could you clarify more what you expect to see? Looks like the checkpoint contains one array of shape (2,) - is it supposed to be a tree of more arrays than that?

Initial thought is that you're mixing old and new APIs in a less-than-ideal way. Since EquinoxCheckpointHandler is registered, you should just be able to do:

manager = ocp.CheckpointManager(
    directory=checkpoint_dir,
    options=options,
)
manager.save(10, args=EquinoxStateSave(net0), metrics={"loss": 0.1})

Yes, build_random_nn builds an MLP with the following structure:

(Pdb) p net0
MLP(
  layers=(
    Linear(
      weight=f32[10,2],
      bias=f32[10],
      in_features=2,
      out_features=10,
      use_bias=True
    ),
    Linear(
      weight=f32[10,10],
      bias=f32[10],
      in_features=10,
      out_features=10,
      use_bias=True
    ),
    Linear(
      weight=f32[2,10],
      bias=f32[2],
      in_features=10,
      out_features=2,
      use_bias=True
    )
  ),
  activation=<wrapped function relu>,
  final_activation=<function <lambda>>,
  use_bias=True,
  use_final_bias=True,
  in_size=2,
  out_size=2,
  width_size=10,
  depth=2
)
(Pdb) 

I tried removing the checkpointer specification like you suggested, and I still have the same issue.

New code:

def build_random_nn(key):
    mlp = eqx.nn.MLP(
        in_size=2,
        out_size=2,
        width_size=10,
        depth=2,
        key=key,
    )
    return mlp



def test_checkpoint_handler():
    checkpoint_dir = tempfile.mkdtemp()

    def build_manager_and_net(key):
        options = ocp.CheckpointManagerOptions(enable_async_checkpointing=False)

        manager = ocp.CheckpointManager(
            directory=checkpoint_dir,
            options=options,
        )

        net = build_random_nn(key)
        return manager, net

    manager0, net0 = build_manager_and_net(jax.random.PRNGKey(42))

    manager0.save(10, args=EquinoxStateSave(net0), metrics={"loss": 0.1})
    manager0.wait_until_finished()

    manager1, net1 = build_manager_and_net(jax.random.PRNGKey(420))

    net1_restore = manager0.restore(10, args=EquinoxStateRestore(net1))  # noqa: F841

Resulting file:

awang@mfews-awang:~ $ ls /tmp/tmpl7gu_xsw/10/default
model.eqx.npy
awang@mfews-awang:~ $ cat /tmp/tmpl7gu_xsw/10/default/model.eqx.npy 
�NUMPYv{'descr': '<f4', 'fortran_order': False, 'shape': (2,), }                                                            
���>����awang@mfews-awang:~ $ 

I'd expect a model.eqx

I hope you have already tried the following by removing Orbax from the picture:

net  = build_random_nn(key)
full_path = directory / "model.eqx"
eqx.tree_serialise_leaves(full_path, net, is_leaf=eqx.is_array_like)

Was the outcome as expected?

Yep. Here's a pytest that reproduces the issue, along with one showing the usual eqx.tree_serialise_leaves working as intended: https://github.com/allen-adastra/orbax/blob/allenw/repro_eqx_bug/repro_eqx_bug.py

When I check the temp dir used for the test, I see this:

$ ls /tmp/eqx_bug/
10  model_eqx_save.eqx
$ ls /tmp/eqx_bug/10/default/
model.eqx.npy

Thank you for sharing the tests!

Your test helped me to run it locally and identify the issue. If you correct the build_random_nn(...) function as follows, you should not face this odd issue.

def build_random_nn(key: jaxtyping.PRNGKeyArray) -> eqx.nn.MLP:
  return eqx.nn.MLP(in_size=2, out_size=1, width_size=64, depth=2, key=key)

A unit test is attached here.
test.py.txt

Thanks for the quick response! Unfortunately, changing that one function did not fix it for me... I just pushed the unit test with the suggested change:
https://github.com/allen-adastra/orbax/blob/allenw/repro_eqx_bug/repro_eqx_bug.py

Besides the type checker issue above, I hit another one and had to convert full_path to str before passing it to the Equinox serializer.

full_path = directory / 'model.eqx'
eqx.tree_serialise_leaves(
    str(full_path), args.item, is_leaf=eqx.is_array_like
)

Can you please try that?

Aha this one fixed it!

So I reverted the change to build_random_nn, and it still works.

So the key fix here was using str(full_path). Well... that seems strange :)
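
A likely explanation, though nothing in this thread confirms it: if eqx.tree_serialise_leaves opens the target itself only when it receives a str or pathlib.Path, then any other path-like object (assuming epath.Path is one) would be handed unchanged to np.save once per leaf. np.save documents that it appends .npy to a filename lacking that extension, so every leaf would overwrite the same model.eqx.npy and only the last one, the final bias of shape (2,), would survive. The sketch below reproduces that behaviour with NumPy alone. PathLikeOnly and save_leaves are hypothetical stand-ins, not Equinox or Orbax code.

```python
import os
import tempfile

import numpy as np


class PathLikeOnly:
    """Hypothetical stand-in: os.PathLike, but neither str nor pathlib.Path."""

    def __init__(self, path: str):
        self._path = path

    def __fspath__(self) -> str:
        return self._path


def save_leaves(target, leaves):
    """Mimic a serializer that calls np.save once per leaf on the same target."""
    for leaf in leaves:
        np.save(target, leaf)


tmp = tempfile.mkdtemp()
leaves = [np.zeros((10, 2), np.float32), np.ones((2,), np.float32)]

# Path-like target: np.save treats it as a filename and appends ".npy",
# so each leaf overwrites the same file and only the last one remains.
save_leaves(PathLikeOnly(os.path.join(tmp, "model.eqx")), leaves)
files = sorted(os.listdir(tmp))
survivor = np.load(os.path.join(tmp, "model.eqx.npy"))

# Open file handle: np.save writes to the stream as given,
# so the leaves are appended one after another.
with open(os.path.join(tmp, "all_leaves.eqx"), "wb") as f:
    save_leaves(f, leaves)
with open(os.path.join(tmp, "all_leaves.eqx"), "rb") as f:
    restored = [np.load(f), np.load(f)]

print(files, survivor.shape, [r.shape for r in restored])
```

Under that assumption, str(full_path) works because the serializer recognises the string, opens one file, and streams every leaf into it.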

It was strange for sure.

Please consider enabling type checking.

Hmm, I'm curious how type checking would have helped in this situation?
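
One way type checking could have helped, sketched under assumptions: if the serializer's path argument is annotated as something like Union[str, pathlib.Path, BinaryIO], a static checker such as mypy or pyright would be expected to flag a call that passes a different path class. The annotation and the serialise function below are hypothetical, not copied from Equinox, and pathlib.PurePosixPath stands in for epath.Path.

```python
import pathlib
from typing import BinaryIO, Union

# Assumed shape of the accepted argument; not taken from Equinox's source.
PathOrFile = Union[str, pathlib.Path, BinaryIO]


def serialise(path_or_file: PathOrFile) -> str:
    """Hypothetical front end: report which branch would handle the argument."""
    if isinstance(path_or_file, (str, pathlib.Path)):
        return "opens the file itself"
    return "passes the object through unchanged"


# Stand-in for a path class that falls outside the annotated union.
full_path = pathlib.PurePosixPath("/tmp/ckpt") / "model.eqx"

# A static checker would be expected to reject this call; at runtime nothing
# complains and the object silently takes the pass-through branch.
unchecked = serialise(full_path)  # type: ignore[arg-type]

# Converting to str satisfies the annotation and takes the intended branch.
checked = serialise(str(full_path))

print(unchecked)
print(checked)
```

The runtime result differs only in which branch runs, which is why the mismatch is easy to miss without a checker.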