Orbax checkpoint for LogicallyPartitioned params
mmorinag127 opened this issue
Hi all,
I want to ask if there is any way to save parameters that carry logical partitioning annotations. Below is a simple example of the problem I've run into.
import os
import shutil

import jax
from jax import random, numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from flax import linen as nn
import orbax.checkpoint as ocp
class TestModel(nn.Module):
    dim: int = 8

    @nn.compact
    def __call__(self, x, training=True):
        out = nn.Dense(
            self.dim,
            kernel_init=nn.with_logical_partitioning(jax.nn.initializers.ones, ('input', 'embed')),
            use_bias=False,
        )(x)
        # out = nn.BatchNorm(use_running_average=not training)(out)
        out = nn.Dense(
            self.dim,
            kernel_init=nn.with_logical_partitioning(jax.nn.initializers.ones, ('embed', 'output')),
            use_bias=False,
        )(out)
        return out
def test_save(ckpt_dir):
    # Note: this must take effect before JAX initializes its backend.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
    if os.path.exists(ckpt_dir):
        print(f'{ckpt_dir} was removed')
        shutil.rmtree(ckpt_dir)
    os.makedirs(ckpt_dir)
    print(f'{ckpt_dir} was made')

    # Build a 4x2 device mesh with a data axis and a model axis.
    device_mesh = mesh_utils.create_device_mesh([4, 2])
    mesh = Mesh(devices=device_mesh, axis_names=['data', 'model'])
    rng, data_rng = random.split(random.key(0))
    x = random.normal(data_rng, (64, 32, 4))
    x_sharding = NamedSharding(mesh, PartitionSpec('data'))
    x = jax.device_put(x, x_sharding)

    model = TestModel(dim=4)
    rules = (('embed', 'model'), ('input', None), ('output', None))

    def minit(rng, data):
        variables = model.init(rng, data)
        return variables

    # Evaluate only the shapes, then map the logical axis names to mesh shardings.
    abst_var = jax.eval_shape(minit, rng, x)
    spec_var = nn.get_partition_spec(abst_var)
    var_sharding = nn.logical_to_mesh_sharding(spec_var, mesh, rules)
    rng_sharding = NamedSharding(mesh, PartitionSpec())  # replicate the PRNG key
    variables = jax.jit(minit, in_shardings=(rng_sharding, x_sharding),
                        out_shardings=var_sharding)(rng, x)
    print('save')
    print(variables['params']['Dense_0']['kernel'])

    options = ocp.CheckpointManagerOptions(
        max_to_keep=2,
        create=True,
        best_mode='min',
        best_fn=lambda m: m['loss'],
    )
    ckpt_mgr = ocp.CheckpointManager(ckpt_dir, options=options,
                                     item_handlers=ocp.StandardCheckpointHandler())
    for step in range(5):
        loss = (step - 2)**2 + 0.001
        # save_args = orbax_utils.save_args_from_target(variables)
        save_args = ocp.args.StandardSave(variables)
        ckpt_mgr.save(step, args=save_args, metrics={'loss': loss})
def test_restore(ckpt_dir):
    config = {'create': True, 'keep_period': 1, 'best_mode': 'min'}
    options = ocp.CheckpointManagerOptions(**config, best_fn=lambda metrics: metrics['loss'])
    ckpt_mgr = ocp.CheckpointManager(ckpt_dir, options=options,
                                     item_handlers=ocp.StandardCheckpointHandler())
    variables = ckpt_mgr.restore(ckpt_mgr.best_step())
    print('restore')
    print(variables['params']['Dense_0']['kernel'])
With this example, I got the following:
save
LogicallyPartitioned(value=Array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32), names=('input', 'embed'), mesh=None, rules=None)
restore
{'value': Array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32)}
The restored parameters differ from what was saved: the LogicallyPartitioned wrapper is lost on restore. I would like to reuse these parameters in another model, i.e. transfer learning, so is there any way to avoid this problem?
This is because LogicallyPartitioned is a runtime object, and you'll need it to be available when you load the checkpoint, likely via the item argument of the Orbax CheckpointManager (see the Orbax documentation for an example use).
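A minimal sketch of that first option, assuming the minit, rng, and x from the example above are in scope, and that your Orbax version restores into a boxed target via ocp.args.StandardRestore:

# Rebuild the abstract variable tree; it still carries the
# LogicallyPartitioned boxes (holding ShapeDtypeStructs).
abst_var = jax.eval_shape(minit, rng, x)
# Hand it to restore as the target structure, so Orbax restores
# into the same boxed pytree instead of a raw dict.
variables = ckpt_mgr.restore(
    ckpt_mgr.best_step(),
    args=ocp.args.StandardRestore(abst_var),
)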
Another option is to save the raw dictionary of JAX arrays as the checkpoint, which might be more intuitive (see the sketch after this list). Like:
- Instead of saving variables, save nn.meta.unbox(variables).
- When restoring, you'll get a dict of raw JAX arrays. You need to get the annotations back, e.g. via abst_var = jax.eval_shape(minit, rng, x).
- Run nn.meta.replace_boxed(abst_var, from_ckpt) to get the boxed variables back.
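Put together, a minimal sketch of this second option, reusing minit, rng, x, and the checkpoint managers from the example above (the save line belongs inside the save loop of test_save, where step and loss are defined; from_ckpt is just the raw restored dict):

# Save the raw arrays instead of the boxed variables.
raw_variables = nn.meta.unbox(variables)
ckpt_mgr.save(step, args=ocp.args.StandardSave(raw_variables), metrics={'loss': loss})

# Restoring yields plain JAX arrays without any metadata boxes.
from_ckpt = ckpt_mgr.restore(ckpt_mgr.best_step())

# Recover the LogicallyPartitioned annotations from the abstract tree
# and put the restored values back into the boxes.
abst_var = jax.eval_shape(minit, rng, x)
variables = nn.meta.replace_boxed(abst_var, from_ckpt)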
I hope this helps!
Thanks a lot!
Indeed, it works in my case.
It would be nice to have some documentation covering this rather hidden method.
Anyway, thank you!