Saving doesn't work and results in an extra *.npy extension?
allen-adastra opened this issue
Hello!
I'm trying to make orbax play nicely with equinox's serialise/deserialise functions, but I'm running into a mysterious issue. After saving, this is all I get in my save directory:
awang@mfews-awang:~ $ ls /tmp/tmpsre4jda3/10/default
model.eqx.npy
awang@mfews-awang:~ $ cat /tmp/tmpsre4jda3/10/default/model.eqx.npy
[binary NUMPY magic] {'descr': '<f4', 'fortran_order': False, 'shape': (2,), }
[binary float32 data]
awang@mfews-awang:~ $
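Rather than running cat on a binary file, you can have NumPy decode the .npy header directly. The sketch below writes a throwaway stand-in file in place of the real checkpoint, and the temp path is only illustrative. It reads back the same information the cat output shows: dtype `<f4` and shape `(2,)`.

```python
import os
import tempfile

import numpy as np
from numpy.lib import format as npy_format

# Illustrative stand-in for the checkpoint file: one float32 array of shape (2,).
tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "model.eqx.npy")
np.save(path, np.zeros(2, dtype=np.float32))

# Decode only the header instead of dumping raw bytes to the terminal.
with open(path, "rb") as f:
    version = npy_format.read_magic(f)  # (major, minor); small headers are written as 1.0
    shape, fortran_order, dtype = npy_format.read_array_header_1_0(f)

print(version, shape, fortran_order, dtype)
```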
My checkpoint handler:
from dataclasses import dataclass

import equinox as eqx
import orbax.checkpoint as ocp
from etils import epath


class EquinoxCheckpointHandler(ocp.CheckpointHandler):
    def save(
        self,
        directory: epath.Path,
        args: "EquinoxStateSave",
    ):
        full_path = directory / "model.eqx"
        eqx.tree_serialise_leaves(full_path, args.item, is_leaf=eqx.is_array_like)

    def restore(
        self,
        directory: epath.Path,
        args: "EquinoxStateRestore",
    ) -> eqx.Module:
        loaded = eqx.tree_deserialise_leaves(
            directory / "model.eqx", args.item, is_leaf=eqx.is_array_like
        )
        return loaded


@ocp.args.register_with_handler(EquinoxCheckpointHandler, for_save=True)
@dataclass
class EquinoxStateSave(ocp.args.CheckpointArgs):
    item: eqx.Module


@ocp.args.register_with_handler(EquinoxCheckpointHandler, for_restore=True)
@dataclass
class EquinoxStateRestore(ocp.args.CheckpointArgs):
    item: eqx.Module
And here is my code:
import tempfile

import jax
import orbax.checkpoint as ocp


def test_checkpoint_handler():
    checkpoint_dir = tempfile.mkdtemp()

    def build_manager_and_net(key):
        options = ocp.CheckpointManagerOptions(enable_async_checkpointing=False)
        manager = ocp.CheckpointManager(
            directory=checkpoint_dir,
            options=options,
            checkpointers=ocp.Checkpointer(EquinoxCheckpointHandler()),
        )
        net = build_random_nn(key)
        return manager, net

    manager0, net0 = build_manager_and_net(jax.random.PRNGKey(42))
    manager0.save(10, args=EquinoxStateSave(net0), metrics={"loss": 0.1})
    manager0.wait_until_finished()
    manager1, net1 = build_manager_and_net(jax.random.PRNGKey(420))
    net1_restore = manager0.restore(10, args=EquinoxStateRestore(net1))
Could you clarify what you expect to see? It looks like the checkpoint contains a single array of shape (2,); is it supposed to be a tree with more arrays than that?
My initial thought is that you're mixing the old and new APIs in a less-than-ideal way. Since EquinoxCheckpointHandler is registered with the args classes, you should be able to just do:
manager = ocp.CheckpointManager(
    directory=checkpoint_dir,
    options=options,
)
manager.save(10, args=EquinoxStateSave(net0), metrics={"loss": 0.1})
Yes, build_random_nn builds an MLP with the following structure:
(Pdb) p net0
MLP(
  layers=(
    Linear(
      weight=f32[10,2],
      bias=f32[10],
      in_features=2,
      out_features=10,
      use_bias=True
    ),
    Linear(
      weight=f32[10,10],
      bias=f32[10],
      in_features=10,
      out_features=10,
      use_bias=True
    ),
    Linear(
      weight=f32[2,10],
      bias=f32[2],
      in_features=10,
      out_features=2,
      use_bias=True
    )
  ),
  activation=<wrapped function relu>,
  final_activation=<function <lambda>>,
  use_bias=True,
  use_final_bias=True,
  in_size=2,
  out_size=2,
  width_size=10,
  depth=2
)
(Pdb)
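This structure matches the single (2,) array found in the checkpoint. The sketch below rests on two assumptions:

- Equinox flattens the MLP layer by layer as (weight, bias).
- Only array leaves are written, so the activations and the static int/bool fields are skipped.

Under those assumptions, the sequence of serialised arrays can be worked out by hand, and the last one is the final layer's bias, f32[2]. The sketch does that bookkeeping in plain Python and does not call Equinox.

```python
# Hypothetical re-derivation of the array leaves of
# eqx.nn.MLP(in_size=2, out_size=2, width_size=10, depth=2),
# assuming (weight, bias) order per Linear layer and no non-array leaves written.
in_size, out_size, width_size, depth = 2, 2, 10, 2

# depth=2 hidden layers means depth + 1 = 3 Linear layers.
dims = [in_size] + [width_size] * depth + [out_size]  # [2, 10, 10, 2]

leaf_shapes = []
for fan_in, fan_out in zip(dims[:-1], dims[1:]):
    leaf_shapes.append((fan_out, fan_in))  # Linear weight: (out_features, in_features)
    leaf_shapes.append((fan_out,))         # Linear bias: (out_features,)

print(leaf_shapes)
print("last leaf written:", leaf_shapes[-1])
```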
I tried removing the checkpointers argument as you suggested, and I still have the same issue.
New code:
import tempfile

import equinox as eqx
import jax
import orbax.checkpoint as ocp


def build_random_nn(key):
    mlp = eqx.nn.MLP(
        in_size=2,
        out_size=2,
        width_size=10,
        depth=2,
        key=key,
    )
    return mlp


def test_checkpoint_handler():
    checkpoint_dir = tempfile.mkdtemp()

    def build_manager_and_net(key):
        options = ocp.CheckpointManagerOptions(enable_async_checkpointing=False)
        manager = ocp.CheckpointManager(
            directory=checkpoint_dir,
            options=options,
        )
        net = build_random_nn(key)
        return manager, net

    manager0, net0 = build_manager_and_net(jax.random.PRNGKey(42))
    manager0.save(10, args=EquinoxStateSave(net0), metrics={"loss": 0.1})
    manager0.wait_until_finished()
    manager1, net1 = build_manager_and_net(jax.random.PRNGKey(420))
    net1_restore = manager0.restore(10, args=EquinoxStateRestore(net1))  # noqa: F841
Resulting file:
awang@mfews-awang:~ $ ls /tmp/tmpl7gu_xsw/10/default
model.eqx.npy
awang@mfews-awang:~ $ cat /tmp/tmpl7gu_xsw/10/default/model.eqx.npy
[binary NUMPY magic] {'descr': '<f4', 'fortran_order': False, 'shape': (2,), }
[binary float32 data]
awang@mfews-awang:~ $
I'd expect a single model.eqx file.
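The extra extension is consistent with how NumPy's np.save treats its first argument. Given a filename (a str or any os.PathLike) instead of an open file, it appends .npy whenever the name does not already end in it. The minimal demonstration below uses a throwaway directory. Whether Equinox hands the path to np.save in this way is a separate question, and this snippet does not show it.

```python
import os
import tempfile

import numpy as np

tmpdir = tempfile.mkdtemp()
target = os.path.join(tmpdir, "model.eqx")

# A filename (not an open file) without a .npy suffix gets ".npy" appended by np.save.
np.save(target, np.zeros(2, dtype=np.float32))

print(sorted(os.listdir(tmpdir)))
```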
Have you already tried the following, with Orbax taken out of the picture?
net = build_random_nn(key)
full_path = directory / "model.eqx"
eqx.tree_serialise_leaves(full_path, net, is_leaf=eqx.is_array_like)
Was the outcome as expected?
Yep. Here's a pytest test that reproduces the issue, along with one showing plain eqx.tree_serialise_leaves working as intended: https://github.com/allen-adastra/orbax/blob/allenw/repro_eqx_bug/repro_eqx_bug.py
When I check the temp dir used for the test, I see this:
$ ls /tmp/eqx_bug/
10 model_eqx_save.eqx
$ ls /tmp/eqx_bug/10/default/
model.eqx.npy
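For comparison, the plain-Equinox test produces a single model_eqx_save.eqx file. The sketch below reproduces that layout with NumPy alone. It assumes, without verifying it here, that tree_serialise_leaves writes each array leaf with np.save onto one open file in tree order, and that tree_deserialise_leaves reads the leaves back the same way. The leaf values are arbitrary, and only the shapes mirror the MLP above.

```python
import os
import tempfile

import numpy as np

# Stand-in leaves with the same shapes as the MLP above (values are arbitrary).
shapes = [(10, 2), (10,), (10, 10), (10,), (2, 10), (2,)]
leaves = [np.full(s, i, dtype=np.float32) for i, s in enumerate(shapes)]

path = os.path.join(tempfile.mkdtemp(), "model.eqx")

# One open handle for the whole tree: each np.save call appends one .npy record,
# and the file name is kept exactly as given (no ".npy" suffix added).
with open(path, "wb") as f:
    for leaf in leaves:
        np.save(f, leaf)

# Reading back from one handle, in the same order, recovers every leaf.
with open(path, "rb") as f:
    restored = [np.load(f) for _ in leaves]

print(os.listdir(os.path.dirname(path)), [r.shape for r in restored])
```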
Thank you for sharing the tests!
Your test helped me run it locally and identify the issue. If you change the build_random_nn(...) function as follows, you should no longer run into this odd issue:
def build_random_nn(key: jaxtyping.PRNGKeyArray) -> eqx.nn.MLP:
    return eqx.nn.MLP(in_size=2, out_size=1, width_size=64, depth=2, key=key)
A unit test is attached here: test.py.txt
Thanks for the quick response! Unfortunately, changing that one function did not fix it for me... I just pushed the unit test with the suggested change:
https://github.com/allen-adastra/orbax/blob/allenw/repro_eqx_bug/repro_eqx_bug.py
Apart from the type checker issue above, I also hit a second type checker issue and had to convert full_path to str before passing it to the Equinox serialiser:
full_path = directory / 'model.eqx'
eqx.tree_serialise_leaves(
    str(full_path), args.item, is_leaf=eqx.is_array_like
)
Can you please try that?
Aha, this one fixed it!
So I reverted the change to build_random_nn, and it still works.
So the key fix here was using str(full_path). Well... that seems strange :)
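One plausible explanation follows. It is an assumption, not a reading of the Equinox source:

- tree_serialise_leaves opens the file itself only when it receives a str or a pathlib.Path. Any other argument is treated as an already-open file and passed straight to np.save.
- etils' epath.Path appears to derive from pathlib.PurePosixPath rather than pathlib.Path. I have not checked this here.

If both hold, the epath.Path falls into the second branch. Every leaf is then written to model.eqx.npy, each call truncating the previous one, so only the final f32[2] bias remains.

The function serialise_leaves_sketch below is hypothetical. It uses PurePosixPath as a stand-in for epath.Path.

```python
import os
import pathlib
import tempfile

import numpy as np


def serialise_leaves_sketch(path_or_file, leaves):
    """Hypothetical stand-in for the dispatch suspected inside eqx.tree_serialise_leaves."""
    if isinstance(path_or_file, (str, pathlib.Path)):
        # Recognised path types: open the file once and write every leaf into it.
        with open(path_or_file, "wb") as f:
            for leaf in leaves:
                np.save(f, leaf)
    else:
        # Anything else is assumed to be an open file and handed straight to np.save.
        # A path-like object has no .write(), so np.save treats it as a filename,
        # appends ".npy" and truncates that file on every call.
        for leaf in leaves:
            np.save(path_or_file, leaf)


leaves = [np.zeros((10, 2), np.float32), np.zeros(10, np.float32), np.ones(2, np.float32)]

# PurePosixPath: path-like, but not an instance of pathlib.Path.
d1 = tempfile.mkdtemp()
serialise_leaves_sketch(pathlib.PurePosixPath(d1) / "model.eqx", leaves)
print(os.listdir(d1), np.load(os.path.join(d1, "model.eqx.npy")).shape)

# Converting to str first takes the "open once" branch, as in the fix above.
d2 = tempfile.mkdtemp()
serialise_leaves_sketch(str(pathlib.PurePosixPath(d2) / "model.eqx"), leaves)
print(os.listdir(d2))
```

If this explanation is right, it would also account for the revert making no difference. The network architecture only changes which leaf happens to be written last.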
It was strange for sure.
Please consider enabling type checking.
Hmm, I'm curious how type checking would have helped in this situation?
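Here is one possible answer, assuming I'm reading Equinox's signature correctly. tree_serialise_leaves annotates its first argument as Union[str, pathlib.Path, BinaryIO]. A static checker such as pytype or mypy should therefore reject an epath.Path at the call site, which is exactly where the runtime silently took the wrong branch.

The runtime side of the same point is shown below. A path object can be path-like without being a pathlib.Path, and os.fspath turns it into the plain str that fixed the problem. to_local_path is a hypothetical helper, and PurePosixPath again stands in for epath.Path.

```python
import os
import pathlib
from typing import Union


def to_local_path(path: Union[str, "os.PathLike[str]"]) -> str:
    """Hypothetical helper: normalise any local path-like (e.g. an epath.Path) to a str."""
    return os.fspath(path)


p = pathlib.PurePosixPath("/tmp/ckpt/10/default") / "model.eqx"
print(isinstance(p, pathlib.Path))  # a pathlib.Path isinstance check does not match it
print(to_local_path(p))
```

This only helps for local paths. A remote epath.Path (for example a gs:// URI) becomes a string that the built-in open() cannot handle. In that case, passing a file object obtained from the path's own open("wb") method is presumably the safer choice.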