google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io

Orbax checkpoint for LogicallyPartitioned params

mmorinag127 opened this issue · comments

Hi all,

I'd like to ask whether there is a way to save parameters that use logical partitioning.
Below is a simple example of the problem I ran into.

import os
import shutil

import jax
from jax import random, numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding

import orbax.checkpoint as ocp
from flax import linen as nn


class TestModel(nn.Module):
    dim: int = 8
    
    @nn.compact
    def __call__(self, x, training=True):
        # with_logical_partitioning boxes each kernel as LogicallyPartitioned,
        # attaching logical axis names to the created parameter.
        out = nn.Dense(self.dim,
                       kernel_init=nn.with_logical_partitioning(jax.nn.initializers.ones, ('input', 'embed')),
                       use_bias=False)(x)
        # out = nn.BatchNorm(use_running_average=not training)(out)
        out = nn.Dense(self.dim,
                       kernel_init=nn.with_logical_partitioning(jax.nn.initializers.ones, ('embed', 'output')),
                       use_bias=False)(out)
        return out

def test_save(ckpt_dir):
    # Note: this only takes effect if JAX has not yet initialized its backends.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
    if os.path.exists(ckpt_dir):
        print(f'{ckpt_dir} was removed')
        shutil.rmtree(ckpt_dir)
    os.makedirs(ckpt_dir)
    print(f'{ckpt_dir} was made')
    
    # Build a 4x2 device mesh with 'data' and 'model' axes.
    device_mesh = mesh_utils.create_device_mesh([4, 2])
    mesh = Mesh(devices=device_mesh, axis_names=['data', 'model'])
    
    rng, data_rng = random.split(random.key(0))
    x = random.normal(data_rng, (64, 32, 4))
    x_sharding = NamedSharding(mesh, PartitionSpec('data', ))
    x = jax.device_put(x, x_sharding)
    model = TestModel(dim=4)
    rules = (('embed', 'model'), ('input', None), ('output', None))
    
    def minit(rng, data):
        variables = model.init(rng, data)
        return variables
    
    # Evaluate init abstractly to get shapes with LogicallyPartitioned boxes,
    # derive their logical PartitionSpecs, and map them to mesh shardings.
    abst_var = jax.eval_shape(minit, rng, x)
    spec_var = nn.get_partition_spec(abst_var)
    var_sharding = nn.logical_to_mesh_sharding(spec_var, mesh, rules)
    rng_sharding = NamedSharding(mesh, PartitionSpec())  # fully replicated key
    variables = jax.jit(minit, in_shardings=(rng_sharding, x_sharding), out_shardings=var_sharding)(rng, x)
    print('save')
    print(variables['params']['Dense_0']['kernel'])
    
    options = ocp.CheckpointManagerOptions(
        max_to_keep=2,
        create=True,
        best_mode='min',
        best_fn=lambda m: m['loss'],
    )
    ckpt_mgr = ocp.CheckpointManager(ckpt_dir, options=options, item_handlers=ocp.StandardCheckpointHandler())
    for step in range(5):
        loss = (step - 2)**2 + 0.001
        # save_args = orbax_utils.save_args_from_target(variables)
        save_args = ocp.args.StandardSave(variables)
        ckpt_mgr.save(step, args=save_args, metrics={'loss': loss})


def test_restore(ckpt_dir):
    config = {'create': True, 'keep_period': 1, 'best_mode': 'min'}
    
    options = ocp.CheckpointManagerOptions(**config, best_fn=lambda metrics: metrics['loss'])
    ckpt_mgr = ocp.CheckpointManager(ckpt_dir, options=options, item_handlers=ocp.StandardCheckpointHandler())
    # No restore target is passed here, so Orbax returns plain nested dicts
    # and the LogicallyPartitioned metadata is lost.
    variables = ckpt_mgr.restore(ckpt_mgr.best_step())
    print('restore')
    print(variables['params']['Dense_0']['kernel'])
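
For reference, a minimal driver sketch for the two functions above (the ./ckpt path here is just a placeholder I chose, nothing required):

if __name__ == '__main__':
    # Orbax generally expects an absolute checkpoint directory.
    ckpt_dir = os.path.abspath('./ckpt')
    test_save(ckpt_dir)
    test_restore(ckpt_dir)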

Running this example, I get the following output:

save
LogicallyPartitioned(value=Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32), names=('input', 'embed'), mesh=None, rules=None)
restore
{'value': Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)}

The restored parameters differ from what was saved: the LogicallyPartitioned metadata (names, mesh, rules) is gone, and only a plain dict holding the raw array comes back. I need that metadata because I'd like to reuse these parameters in another model, i.e. for transfer learning.

Is there any way to avoid this problem?

This is because LogicallyPartitioned is a runtime object, and you'll need it to be available when you load the checkpoint, likely via the item argument of Orbax's CheckpointManager (see the example use in the Orbax docs).
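
For example, a rough sketch of that approach, reusing minit, rng, x, and ckpt_mgr from your example (the exact ocp.args API may differ between Orbax versions):

# Rebuild the abstract pytree; it still contains the LogicallyPartitioned boxes.
abst_var = jax.eval_shape(minit, rng, x)
# Pass it as the restore target so Orbax places the restored leaves back
# inside the boxes instead of returning plain dicts.
variables = ckpt_mgr.restore(
    ckpt_mgr.best_step(),
    args=ocp.args.StandardRestore(abst_var),
)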

Another option is to save the raw dictionary of JAX arrays as the checkpoint, which might be more intuitive (see the sketch after this list):

  1. Instead of saving variables, save nn.meta.unbox(variables).
  2. When restoring, you'll get a dict of raw JAX arrays. You then need to get the annotations back, e.g. via abst_var = jax.eval_shape(minit, rng, x).
  3. Run nn.meta.replace_boxed(abst_var, from_ckpt) to get the boxed variables back.
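
Putting those steps together, a rough sketch (again reusing minit, rng, x, and ckpt_mgr from your example; details may vary by Orbax version):

# 1. Save plain JAX arrays instead of the boxed variables.
raw_variables = nn.meta.unbox(variables)
ckpt_mgr.save(step, args=ocp.args.StandardSave(raw_variables), metrics={'loss': loss})

# 2. Restoring without a target yields a dict of raw arrays.
from_ckpt = ckpt_mgr.restore(ckpt_mgr.best_step())

# 3. Re-attach the annotations from the abstract tree.
abst_var = jax.eval_shape(minit, rng, x)
variables = nn.meta.replace_boxed(abst_var, from_ckpt)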

I hope this helps!

Thanks a lot!
Indeed, it works in my case.
It would be great to have some documentation covering this somewhat hidden method.
Anyway, thank you!