Freezing parameters for transfer learning
siasio opened this issue
In optax, there is a function optax.multi_transform which can be used to apply different optimization functions to different parameters in a model. In particular, it's possible to tell the optimizer not to update certain parameters. This way, one could e.g. freeze the backbone of a pre-trained model during transfer learning.
Is there similar functionality in opax? How can I freeze a model's backbone and train only its head?
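For reference, the optax pattern looks roughly like this (a minimal sketch with a toy parameter dict using "backbone" and "head" keys):

import jax.numpy as jnp
import optax

# Toy parameter tree standing in for a real model: a "backbone" and a "head" group.
params = {"backbone": {"w": jnp.ones((3, 3))}, "head": {"w": jnp.ones((3,))}}

# Map each group to its own transformation: backbone updates are zeroed out,
# the head is trained with SGD. The label tree is a prefix of the params tree.
tx = optax.multi_transform(
    {"frozen": optax.set_to_zero(), "trainable": optax.sgd(1e-3, momentum=0.9)},
    {"backbone": "frozen", "head": "trainable"},
)
opt_state = tx.init(params)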
What is needed is a function following the conventions from transform.py, e.g. like this one: https://github.com/NTT123/opax/blob/main/opax/transform.py#L28
We want a different scaling of gradient updates for different parts of the network. I dealt with this by keeping the modules the network consists of (the backbone and the head) in a dictionary:
class TransferResnet(pax.Module):
    def __init__(self, backbone, head):
        super().__init__()
        # Keep the submodules in a dict so that the keys "backbone" and "head"
        # show up in the parameter pytree paths.
        self.module_dict = {"backbone": backbone, "head": head}

    def __call__(self, input):
        x = self.module_dict["backbone"](input)
        x = self.module_dict["head"](x)
        return x
This way, the module names will be present in the parameter pytree paths, so we can reference them with jax.tree_util.tree_map_with_path.
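A quick way to check this (a sketch; it assumes import jax and pax are in scope, and backbone and head stand for whatever pre-built pax modules you are transferring) is to print the tree paths and confirm the dictionary keys appear in them:

model = TransferResnet(backbone, head)

# Print the path of every leaf in the parameter tree; the dict keys are rendered
# by keystr as ['backbone'] / ['head'], which is what the transformation below matches on.
jax.tree_util.tree_map_with_path(
    lambda path, leaf: print(jax.tree_util.keystr(path)),
    model.parameters(),
)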
I wanted to freeze the backbone parameters during the first stage of training and unfreeze them at a later stage, so I went for the following transformation:
from typing import Callable

import jax
import jax.numpy as jnp
from opax import GradientTransformation  # opax's base gradient-transformation class


def multi_transform(schedule_fn: Callable[[jnp.ndarray], jnp.ndarray]):
    class MultiTransform(GradientTransformation):
        count: jnp.ndarray
        backbone_multiplier: jnp.ndarray

        def __init__(self, params):
            super().__init__(params=params)
            self.schedule_fn = schedule_fn
            self.count = jnp.array(0, dtype=jnp.int32)
            self.backbone_multiplier = self.schedule_fn(self.count)

        def __call__(self, updates, params=None):
            del params
            # Advance the step counter and recompute the backbone multiplier.
            self.count = self.count + 1
            self.backbone_multiplier = self.schedule_fn(self.count)
            # Scale every update whose tree path contains "backbone"; leave the rest untouched.
            updates = jax.tree_util.tree_map_with_path(
                lambda path, u: (
                    self.backbone_multiplier * u
                    if "backbone" in jax.tree_util.keystr(path)
                    else u
                ),
                updates,
            )
            return updates

    return MultiTransform
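The heart of the transformation is the path-based mask. As a standalone sketch (toy update values, reusing the jax/jnp imports above, no opax involved), the masking call does this:

toy_updates = {
    "module_dict": {
        "backbone": {"w": jnp.ones((2, 2))},
        "head": {"w": jnp.ones((2,))},
    }
}
multiplier = jnp.array(0.0)  # what the schedule returns while the backbone is frozen

masked = jax.tree_util.tree_map_with_path(
    lambda path, u: multiplier * u if "backbone" in jax.tree_util.keystr(path) else u,
    toy_updates,
)
# masked["module_dict"]["backbone"]["w"] is now all zeros; the head update is unchanged.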
My schedule function is very simple, returning 0 for the first backbone_lr_steps steps and 1 afterwards:
def lr_backbone_schedule(step):
    return step > backbone_lr_steps
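Note that the comparison returns a boolean, which JAX promotes to 0/1 when it multiplies the float updates. If you prefer the multiplier to be an explicit float (just a stylistic variant, not required), you can cast it:

def lr_backbone_schedule(step):
    # 0.0 while step <= backbone_lr_steps (backbone frozen), 1.0 afterwards
    return jnp.asarray(step > backbone_lr_steps, dtype=jnp.float32)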
The only thing left is adding the transformation to the optimizer chain:
optim = opax.chain(
    opax.add_decayed_weights(weight_decay),
    opax.sgd(lr_schedule, momentum=0.9),
    # Applied last, so backbone updates (including weight decay) are zeroed while frozen.
    multi_transform(lr_backbone_schedule),
).init(transfer_model.parameters())