Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

Use lr setter callback instead of `attr_name` in `LearningRateFinder` and `Tuner`

arthurdjn opened this issue

Description & Motivation

I would like to change the Tuner and LearningRateFinder API so that it can be used with more custom models.

Description

Currently, the learning rate can only be accessed from the LightningModule through an attribute name (either as an attribute or as an entry in the hyperparameters, hparams). This can be configured through the attr_name parameter.

However, I would like to replace the attr_name parameter with a callback lr_setter to allow advanced access, customization and freedom on where the learning rate is located inside the model.

Motivation

While the current implementation suits most use cases, it does not fit some advanced usage. Let's say I want to provide the learning rate through a partially instantiated optimizer. This works very well in a hydra / conf setup.

For example, I find the use of partially instantiated optimizer really helpful for tracking experiments, etc. I often use something like:

import functools
from typing import Callable

from torch.optim import Adam, Optimizer
from lightning.pytorch import LightningModule


PartialOptimizer = Callable[..., Optimizer]

class LitModel(LightningModule):
    def __init__(self, optimizer: PartialOptimizer) -> None:
        super().__init__()
        self.save_hyperparameters()

        self.model = Model(...)

    def configure_optimizers(self) -> Optimizer:
        optimizer = self.hparams.get("optimizer")
        return optimizer(params=self.parameters())


optimizer = functools.partial(Adam, lr=0.001)
model = LitModel(optimizer)

With this implementation it is not possible to use the LearningRateFinder callback or Tuner because the learning rate is not accessible through an attribute or the hyper parameters.

Pitch

I would like to change the parameter attr_name to lr_setter (or similar), which could be a function that sets the learning rate given the model. With that functionality it could be possible to use the Tuner and LearningRateFinder in more advanced cases, while being compatible with attr-defined learning rate:

The type of the lr_setter could be a Callable[[pl.LightningModule, float], None].
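
As a rough illustration of that contract, the sketch below uses a plain stand-in class instead of a real LightningModule so it runs without Lightning installed. LRSetter, DummyModel, and sweep are hypothetical names, not part of Lightning's API; the loop only mimics how a finder might call the setter for each candidate learning rate.

```python
from typing import Any, Callable, List

# Hypothetical alias for the proposed callback type. The issue suggests
# Callable[[pl.LightningModule, float], None]; Any stands in for the module here.
LRSetter = Callable[[Any, float], None]


class DummyModel:
    """Stand-in for a LightningModule that stores its learning rate as an attribute."""

    def __init__(self, learning_rate: float) -> None:
        self.learning_rate = learning_rate


def sweep(model: Any, lr_setter: LRSetter, candidates: List[float]) -> List[float]:
    """Apply each candidate through the setter and record what the model ends up with."""
    seen = []
    for lr in candidates:
        lr_setter(model, lr)  # the only place the sweep touches the model
        seen.append(model.learning_rate)
    return seen


model = DummyModel(learning_rate=0.001)
seen = sweep(model, lambda m, lr: setattr(m, "learning_rate", lr), [1e-4, 1e-3, 1e-2])
```

The sweep never needs to know where the learning rate lives, which is the point of passing a setter instead of an attribute name.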

Use case: attr-defined

This is the case currently described in the docs.

# Using https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html#using-lightning-s-built-in-lr-finder
from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

model = LitModel(learning_rate=0.001)

trainer = Trainer()
tuner = Tuner(trainer)


# Before
lr_finder = tuner.lr_find(model, attr_name="learning_rate")  # Note: attr_name is optional here

# After
from lightning.pytorch.utilities.parsing import lightning_setattr

lr_finder = tuner.lr_find(model, lr_setter=lambda model, lr: lightning_setattr(model, "learning_rate", lr))

Use case: partial optimizer

This is the case that is not compatible with the current API.

# Using the custom LitModel with partial optimizer instead of attr defined learning rate
optimizer = functools.partial(Adam, lr=0.001)
model = LitModel(optimizer)

trainer = Trainer()
tuner = Tuner(trainer)


# Before
lr_finder = tuner.lr_find(model, attr_name=???)  # Not possible to access the learning rate

# After
def partial_setattr(fn: functools.partial, key: str, value: float) -> None:
    *_, (f, args, kwargs, n) = fn.__reduce__()
    kwargs[key] = value
    fn.__setstate__((f, args, kwargs, n))

lr_finder = tuner.lr_find(model, lr_setter=lambda model, lr: partial_setattr(model.hparams["optimizer"], "lr", lr))
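
If relying on __reduce__ and __setstate__ feels fragile, one alternative for the setter body is to rebuild the partial, since wrapping a functools.partial in another functools.partial overrides earlier keyword arguments. This is only a sketch: fake_adam and StubModel are hypothetical stand-ins for Adam and the LitModel above, and set_partial_lr is an illustrative name for a callable that could be passed as the proposed lr_setter.

```python
import functools
from typing import Any, Callable, Dict, Iterable


def fake_adam(params: Iterable[Any], lr: float) -> Dict[str, Any]:
    """Stand-in for torch.optim.Adam that just records its arguments."""
    return {"params": list(params), "lr": lr}


class StubModel:
    """Stand-in for LitModel: keeps the partial optimizer in an hparams dict."""

    def __init__(self, optimizer: Callable[..., Any]) -> None:
        self.hparams = {"optimizer": optimizer}

    def configure_optimizers(self) -> Any:
        return self.hparams["optimizer"](params=[1, 2, 3])


def set_partial_lr(model: StubModel, lr: float) -> None:
    """Replace the stored partial with one whose lr keyword is overridden."""
    model.hparams["optimizer"] = functools.partial(model.hparams["optimizer"], lr=lr)


original = functools.partial(fake_adam, lr=0.001)
model = StubModel(original)
set_partial_lr(model, 0.01)
optimizer = model.configure_optimizers()
```

Unlike partial_setattr, this leaves the original partial untouched and swaps the entry stored on the model instead.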

Alternatives

I already implemented a version that achieves exactly that. The update is minimal, and the core changes are in the _lr_find function of the lightning/pytorch/tuner/lr_finder.py module.

There are only two places to update to make this work:

  1. Remove or adapt the lines that automatically find the attr_name, since the new feature relies on an lr_setter function. When no setter is provided, one could be generated automatically: first check whether an lr or learning_rate attribute is defined and create the matching setter, then check whether a partial optimizer named optim or optimizer is defined and adapt the lr_setter accordingly. This may be unnecessary; the alternative is to require the user to specify the setter function.
  2. Call the lr_setter instead of using the lightning_setattr function.
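
The fallback described in step 1 could look roughly like the following. resolve_lr_setter is a hypothetical helper, not code from lr_finder.py, and it only covers the attribute case (lr or learning_rate) using plain setattr rather than lightning_setattr.

```python
from typing import Any, Callable, Optional

LRSetter = Callable[[Any, float], None]


def resolve_lr_setter(model: Any, lr_setter: Optional[LRSetter] = None) -> LRSetter:
    """Return the user's setter, or build one from a conventional attribute name."""
    if lr_setter is not None:
        return lr_setter
    for name in ("lr", "learning_rate"):
        if hasattr(model, name):
            # Bind the name through a default argument so the closure keeps it.
            return lambda m, value, _name=name: setattr(m, _name, value)
    raise AttributeError(
        "No lr_setter given and the model has no `lr` or `learning_rate` attribute."
    )


class AttrModel:
    """Stand-in model that exposes a learning_rate attribute."""

    def __init__(self) -> None:
        self.learning_rate = 0.001


class OpaqueModel:
    """Stand-in model with no recognizable learning rate attribute."""


attr_model = AttrModel()
resolve_lr_setter(attr_model)(attr_model, 0.05)  # step 2: call the setter

try:
    resolve_lr_setter(OpaqueModel())
    needs_explicit_setter = False
except AttributeError:
    needs_explicit_setter = True
```

Raising when nothing can be inferred corresponds to the "force the user to specify the setter" option mentioned in step 1.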

This feature also requires renaming attr_name to lr_setter to make it obvious that the parameter is a setter.

Additional context

This feature changes the API and is not backward compatible if attr_name is renamed. The capabilities remain the same, with more room for customization.

I already have a working implementation of this feature, compatible with the latest version of lightning. If this is of interest, I can submit a PR.
Thanks, really loving this library!

cc @Borda

@arthurdjn Thanks for the feature request.
Currently I don't understand the use case 100%. My question is, since the contract between Trainer and LM is that the configure_optimizers() hook returns the optimizer, what would that implementation look like in your case? This is where you'd normally configure the learning rate. In other words, why couldn't that special callable you have be applied there?

The attribute the tuner saves the new learning rate to is only a temporary holder in that sense. Ultimately, it needs to be set in configure_optimizers().

The use case is that I don't want to pass the learning rate as an attribute to the LightningModule; instead, I want to pass the partially instantiated optimizer. With this approach, it is not possible to use the Tuner or LearningRateFinder.

# Using the custom LitModel with partial optimizer instead of attr defined learning rate
optimizer = functools.partial(Adam, lr=0.001)
model = LitModel(optimizer)

trainer = Trainer()
tuner = Tuner(trainer)


# This does not work
lr_finder = tuner.lr_find(model, attr_name=???)  # Not possible to access the learning rate

Maybe I am missing something, but I think this use case is a current limitation of the API, hence this feature proposal.

As a callback you proposed this function that updates the lr argument in the partial wrapper:

def partial_setattr(fn: functools.partial, key: str, value: float) -> None:
    *_, (f, args, kwargs, n) = fn.__reduce__()
    kwargs[key] = value
    fn.__setstate__((f, args, kwargs, n))

My question is why couldn't you define your configure_optimizers() this way (pseudo code):

class LitModel(LightningModule):
    def __init__(self, optimizer_cls):
        super().__init__()
        self.optimizer_cls = optimizer_cls
        self.learning_rate = None  # or a default value

    def configure_optimizers(self):
        if self.learning_rate is not None:
            partial_setattr(self.optimizer_cls, "lr", self.learning_rate)
        return self.optimizer_cls(self.parameters())
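
A stand-alone version of this pattern, runnable without Lightning, might look like the sketch below. fake_adam and StubModel are hypothetical stand-ins for the optimizer class and the module, and the override is applied by re-wrapping the partial instead of calling partial_setattr, which is a swapped-in simplification.

```python
import functools
from typing import Any, Callable, Dict, Iterable, List, Optional


def fake_adam(params: Iterable[Any], lr: float) -> Dict[str, Any]:
    """Stand-in for an optimizer class that just records its arguments."""
    return {"params": list(params), "lr": lr}


class StubModel:
    """Stand-in for the module: learning_rate is only a temporary holder."""

    def __init__(self, optimizer_cls: Callable[..., Any]) -> None:
        self.optimizer_cls = optimizer_cls
        self.learning_rate: Optional[float] = None

    def parameters(self) -> List[float]:
        return [0.0]

    def configure_optimizers(self) -> Any:
        optimizer_cls = self.optimizer_cls
        if self.learning_rate is not None:
            # Override the lr baked into the partial with the tuned value.
            optimizer_cls = functools.partial(optimizer_cls, lr=self.learning_rate)
        return optimizer_cls(self.parameters())


model = StubModel(functools.partial(fake_adam, lr=0.001))
before = model.configure_optimizers()["lr"]
model.learning_rate = 0.02  # what the tuner would write via attr_name="learning_rate"
after = model.configure_optimizers()["lr"]
```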

Well, I agree that your solution is pretty simple and clean. I didn't want to have an extra learning_rate attribute in the model to avoid having the same value in different places. I thought that providing the setter function directly to the Tuner might be a clearer alternative. However, I think it's a good alternative.

What do you think?