NTT123 / opax

PAX optimizer library



Freezing parameters for transfer learning

siasio opened this issue


In optax, there is a function, multi_transform, that applies different optimization functions to different parameters of a model. In particular, it can be told not to update certain parameters at all. This makes it possible, e.g., to freeze the backbone of a pre-trained model during transfer learning.

Is there similar functionality in opax? How can I freeze a model's backbone and train only its head?
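The label-based idea described above can be sketched in plain Python on nested dicts: each top-level subtree gets a label, each label maps to its own update rule, and the frozen part gets a rule that discards its updates. The helper names (apply_by_label, set_to_zero, sgd_step) are hypothetical, and this is not optax's actual API.

```python
from typing import Any, Callable, Dict

Tree = Dict[str, Any]


def map_leaves(fn: Callable[[float], float], node: Any) -> Any:
    """Apply fn to every leaf of a nested dict."""
    if isinstance(node, dict):
        return {k: map_leaves(fn, v) for k, v in node.items()}
    return fn(node)


def apply_by_label(updates: Tree, labels: Dict[str, str],
                   transforms: Dict[str, Callable[[float], float]]) -> Tree:
    """Route each top-level subtree to the transform named by its label."""
    return {k: map_leaves(transforms[labels[k]], v) for k, v in updates.items()}


def set_to_zero(u: float) -> float:
    return 0.0  # frozen: the update is discarded


def sgd_step(u: float, lr: float = 0.5) -> float:
    return -lr * u  # plain SGD on the trainable part


grads = {"backbone": {"w": 2.0, "b": 1.0}, "head": {"w": 3.0}}
labels = {"backbone": "frozen", "head": "trainable"}
updates = apply_by_label(grads, labels, {"frozen": set_to_zero, "trainable": sgd_step})
print(updates)  # {'backbone': {'w': 0.0, 'b': 0.0}, 'head': {'w': -1.5}}
```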



What is needed is a gradient transformation that follows the conventions in transform.py, e.g. like this one: https://github.com/NTT123/opax/blob/main/opax/transform.py#L28
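Judging from the transformation written later in this thread, the convention is a factory function that captures the hyperparameters and returns a GradientTransformation subclass whose constructor receives params and whose __call__ maps updates to new updates. A plain-Python sketch of that shape follows; TransformBase and scale are hypothetical stand-ins, not opax's own classes.

```python
from typing import Any, Dict


class TransformBase:
    """Hypothetical stand-in for opax's GradientTransformation base class."""

    def __init__(self, params: Any):
        del params  # a real transformation could size its state from params here


def scale(factor: float):
    """Factory: closes over the hyperparameter and returns a class, not an instance."""

    class Scale(TransformBase):
        def __call__(self, updates: Dict[str, float], params=None) -> Dict[str, float]:
            del params  # unused by this transformation
            return {k: factor * u for k, u in updates.items()}

    return Scale


params = {"w": 1.0, "b": 0.0}
transform = scale(0.5)(params)  # build the class, then instantiate it with the params
print(transform({"w": 4.0, "b": 2.0}))  # {'w': 2.0, 'b': 1.0}
```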

We want to scale the gradient updates differently for different parts of the network. I handled this by storing the network's submodules (backbone and head) in a dictionary:

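```python
import pax


class TransferResnet(pax.Module):
    def __init__(self, backbone, head):
        super().__init__()
        # Storing the submodules in a dict puts "backbone" and "head" into the pytree key paths.
        self.module_dict = {"backbone": backbone, "head": head}

    def __call__(self, x):
        x = self.module_dict["backbone"](x)
        x = self.module_dict["head"](x)
        return x
```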

This way, the names "backbone" and "head" become keys in the pytree, so a function passed to jax.tree_util.tree_map_with_path can tell backbone leaves from head leaves by their key path.
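Why the dictionary helps can be seen with a plain-Python stand-in for path flattening: each leaf's path string contains the dict keys above it, so a substring test on "backbone" separates the two parts. flatten_with_paths is a hypothetical helper whose output only imitates the bracketed style of jax.tree_util.keystr; it does not reproduce how PAX modules themselves are flattened.

```python
from typing import Any, List, Tuple


def flatten_with_paths(tree: Any, prefix: str = "") -> List[Tuple[str, Any]]:
    """Return (path_string, leaf) pairs for a nested dict, depth first."""
    if isinstance(tree, dict):
        pairs: List[Tuple[str, Any]] = []
        for key, subtree in tree.items():
            pairs.extend(flatten_with_paths(subtree, f"{prefix}[{key!r}]"))
        return pairs
    return [(prefix, tree)]


# Toy stand-in for the parameters stored under module_dict.
module_dict = {"backbone": {"conv": {"w": 1.0}}, "head": {"linear": {"w": 2.0}}}
for path, leaf in flatten_with_paths(module_dict):
    print(path, "backbone" in path)
# ['backbone']['conv']['w'] True
# ['head']['linear']['w'] False
```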

I wanted to freeze the backbone parameters during the first stage of training and unfreeze them later, so I wrote this transformation:

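```python
from typing import Callable

import jax
import jax.numpy as jnp
from opax.transform import GradientTransformation


def multi_transform(schedule_fn: Callable[[jnp.ndarray], jnp.ndarray]):
    """Scale the updates of every leaf whose key path contains "backbone" by schedule_fn(step)."""

    class MultiTransform(GradientTransformation):
        # State fields are annotated on the class; at function scope these annotations had no effect.
        count: jnp.ndarray
        backbone_multiplier: jnp.ndarray

        def __init__(self, params):
            super().__init__(params=params)
            self.schedule_fn = schedule_fn
            self.count = jnp.array(0, dtype=jnp.int32)
            self.backbone_multiplier = self.schedule_fn(self.count)

        def __call__(self, updates, params=None):
            del params  # unused
            self.count = self.count + 1
            self.backbone_multiplier = self.schedule_fn(self.count)

            def scale_leaf(path, u):
                # Only leaves under the "backbone" key are rescaled; head updates pass through.
                if "backbone" in jax.tree_util.keystr(path):
                    return self.backbone_multiplier * u
                return u

            return jax.tree_util.tree_map_with_path(scale_leaf, updates)

    return MultiTransform
```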
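The logic of MultiTransform (increment a step counter, recompute the multiplier, rescale only backbone leaves) can be traced in plain Python. This sketch ignores JAX tracing and PAX's state handling entirely; tree_map_with_path_str and BackboneScheduler are hypothetical names, and the two-step freezing schedule is only an example.

```python
from typing import Any, Callable


def tree_map_with_path_str(fn: Callable[[str, Any], Any], tree: Any, prefix: str = "") -> Any:
    """Pure-Python analogue of tree_map_with_path for nested dicts, with string paths."""
    if isinstance(tree, dict):
        return {k: tree_map_with_path_str(fn, v, f"{prefix}[{k!r}]") for k, v in tree.items()}
    return fn(prefix, tree)


class BackboneScheduler:
    """Mirrors MultiTransform without JAX or PAX: count steps, rescale backbone leaves."""

    def __init__(self, schedule_fn: Callable[[int], float]):
        self.schedule_fn = schedule_fn
        self.count = 0
        self.backbone_multiplier = schedule_fn(self.count)

    def __call__(self, updates: Any) -> Any:
        self.count += 1  # incremented first, so the first update sees step 1
        self.backbone_multiplier = self.schedule_fn(self.count)
        m = self.backbone_multiplier
        return tree_map_with_path_str(lambda path, u: m * u if "backbone" in path else u, updates)


scheduler = BackboneScheduler(lambda step: 0.0 if step <= 2 else 1.0)
updates = {"backbone": {"w": 1.0}, "head": {"w": 1.0}}
history = [scheduler(updates) for _ in range(3)]
# Steps 1 and 2: backbone update 0.0, head 1.0. Step 3: both 1.0.
```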

My schedule function is very simple: it returns 0 for the first backbone_lr_steps steps and 1 afterwards:

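```python
def lr_backbone_schedule(step):
    # False (0) while step <= backbone_lr_steps, True (1) afterwards; backbone_lr_steps is defined elsewhere.
    return step > backbone_lr_steps
```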
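Because MultiTransform increments count before evaluating the schedule, the first update uses step 1, so the comparison step > backbone_lr_steps is false for exactly the first backbone_lr_steps updates. A quick check in plain Python, where backbone_lr_steps becomes a parameter with an example value of 3 purely for illustration:

```python
def lr_backbone_schedule(step: int, backbone_lr_steps: int = 3) -> float:
    # Same comparison as the schedule above; backbone_lr_steps = 3 is only an example value.
    return float(step > backbone_lr_steps)


# Updates see steps 1, 2, 3, ... because the counter is incremented before the schedule call.
print([lr_backbone_schedule(step) for step in range(1, 6)])  # [0.0, 0.0, 0.0, 1.0, 1.0]
```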

All that is left is to add the transformation to the optimizer chain:

```python
import opax

optim = opax.chain(
    opax.add_decayed_weights(weight_decay),
    opax.sgd(lr_schedule, momentum=0.9),
    multi_transform(lr_backbone_schedule),  # last in the chain: rescales the final backbone updates
).init(transfer_model.parameters())
```
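To see what this ordering does to a frozen parameter, here is a scalar simulation in plain Python. It assumes that opax.sgd with momentum keeps a standard heavy-ball trace of the incoming updates and then scales it by the negative learning rate (an assumption about its internals, not taken from this thread); simulate_backbone_param is a hypothetical helper. The backbone update is zero during the frozen steps, but the momentum trace keeps accumulating underneath.

```python
from typing import List, Tuple


def simulate_backbone_param(grads: List[float], lr: float = 0.1, momentum: float = 0.9,
                            weight_decay: float = 0.0, frozen_steps: int = 2,
                            param: float = 1.0) -> Tuple[float, List[float], List[float]]:
    """Scalar sketch of chain(add_decayed_weights, sgd with momentum, backbone multiplier)."""
    trace = 0.0
    traces: List[float] = []
    updates: List[float] = []
    for step, g in enumerate(grads, start=1):
        u = g + weight_decay * param          # add_decayed_weights
        trace = momentum * trace + u          # momentum trace (assumed heavy-ball form)
        u = -lr * trace                       # scale by the negative learning rate
        multiplier = 1.0 if step > frozen_steps else 0.0
        u = multiplier * u                    # backbone multiplier, applied last
        param += u
        traces.append(trace)
        updates.append(u)
    return param, traces, updates


param, traces, updates = simulate_backbone_param([1.0, 1.0, 1.0])
# Steps 1-2 leave the parameter unchanged, but by step 3 the trace is about 2.71,
# so the first unfrozen update is about -0.271 rather than the -0.1 a fresh trace would give.
```

In this sketch, moving the multiplier in front of the momentum step would keep the trace at zero while the backbone is frozen.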