ashleve / lightning-hydra-template

PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡🔥⚡

Using OneCycleLR scheduler

phum1901 opened this issue · comments

When using the OneCycleLR scheduler, you need to provide total_steps. The Trainer has a useful property, estimated_stepping_batches, that calculates the total number of steps for you, as in the code below:

def configure_optimizers(self):
    optimizer = ...
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
    )
    ...

But when it comes to the Hydra config, what if I want to change the LR scheduler from this (the MNIST template's config)

scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 10

to this

scheduler:
  _target_: torch.optim.lr_scheduler.OneCycleLR
  _partial_: true
  max_lr: 1e-2
  total_steps: ???

without touching this

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}

How can I do that?

The simplest way would be changing

scheduler = self.hparams.scheduler(optimizer=optimizer)

to:

scheduler = self.hparams.scheduler(optimizer=optimizer, total_steps=self.trainer.estimated_stepping_batches)
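One caveat with the one-line change above: ReduceLROnPlateau does not accept a total_steps argument, so switching the config back to it would raise a TypeError. A minimal sketch of one way around that, assuming the scheduler arrives as a functools.partial (which is what Hydra produces for `_partial_: true`), is to pass total_steps only when the target's signature accepts it. The helper name build_scheduler and the two Fake* classes are hypothetical stand-ins so the example runs without torch; they only mimic the relevant constructor arguments.

```python
import inspect
from functools import partial


def build_scheduler(scheduler_factory, optimizer, total_steps):
    """Call a partially configured scheduler, passing total_steps only if accepted."""
    params = inspect.signature(scheduler_factory).parameters
    kwargs = {"optimizer": optimizer}
    if "total_steps" in params:
        kwargs["total_steps"] = total_steps
    return scheduler_factory(**kwargs)


# Hypothetical stand-ins mimicking the constructor signatures of
# OneCycleLR and ReduceLROnPlateau (not the real torch classes).
class FakeOneCycleLR:
    def __init__(self, optimizer, max_lr, total_steps=None):
        self.optimizer = optimizer
        self.max_lr = max_lr
        self.total_steps = total_steps


class FakeReduceLROnPlateau:
    def __init__(self, optimizer, mode="min", factor=0.1, patience=10):
        self.optimizer = optimizer
        self.mode = mode
        self.factor = factor
        self.patience = patience


# What Hydra would hand over for `_partial_: true` configs.
one_cycle_factory = partial(FakeOneCycleLR, max_lr=1e-2)
plateau_factory = partial(FakeReduceLROnPlateau, mode="min", factor=0.1, patience=10)

one_cycle = build_scheduler(one_cycle_factory, optimizer="opt", total_steps=1000)
plateau = build_scheduler(plateau_factory, optimizer="opt", total_steps=1000)
```

Inside configure_optimizers this would be called as build_scheduler(self.hparams.scheduler, optimizer, self.trainer.estimated_stepping_batches). Note also that OneCycleLR is meant to be stepped after every batch, so the "interval": "epoch" entry in the returned dict would likely need to become "step" when this scheduler is used; otherwise the schedule would only advance once per epoch and never finish its cycle.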