Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

EarlyStopping override disrupts wandb logging frequency

RafiBrent opened this issue

Bug description

When an EarlyStopping callback would halt training before min_epochs has elapsed, the stop request is (correctly) overridden and the warning message shown below is printed. However, starting at the exact step at which the warning is printed, WandbLogger begins logging the train metrics on every single batch rather than every log_every_n_steps steps. This slows training and produces strange output graphs.
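
To make the timing easier to observe, a minimal probe callback like the sketch below (a hypothetical helper, not part of the original reproduction script) can confirm whether trainer.should_stop stays True after the override, which appears to coincide with the change in logging frequency:

# Hypothetical diagnostic callback; uses only the public Callback hooks and the
# trainer.should_stop / trainer.global_step attributes.
from lightning.pytorch.callbacks import Callback


class ShouldStopProbe(Callback):
    """Print trainer.should_stop once per training batch to see when it flips."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.should_stop:
            print(f"global_step={trainer.global_step}: trainer.should_stop is True")

Adding ShouldStopProbe() to the callbacks list in the script below would show whether the flag remains set for every batch after the warning appears.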

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import torch
from lightning.pytorch import LightningModule, Trainer, seed_everything
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    seed_everything(42, workers=True)
    wandb_logger = WandbLogger(project="bug-report",
                               entity="example-user", name="debug_logging")

    # Monitor the step-level train loss so a stop is requested early in training.
    early_stopping_callback = EarlyStopping(monitor="train_loss", patience=2)

    callbacks = [early_stopping_callback]

    kwargs = {
        "log_every_n_steps": 8,  # expect one logged point every 8 steps
        "logger": wandb_logger,
        "num_sanity_val_steps": 0,
        "callbacks": callbacks,
        "val_check_interval": 0.1,
        "max_epochs": 10,
        "min_epochs": 2,  # forces training to continue past the early-stop request
        "deterministic": True,
    }

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(**kwargs)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()

Error messages and logs

Epoch 0:  38%|████████████████████████████████████████▉                                                                  | 12/32 [00:00<00:00, 53.14it/s, v_num=yqz5]

Trainer was signaled to stop but the required `min_epochs=2` or `min_steps=None` has not been met. Training will continue...          
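
A plausible explanation, sketched below as a simplified model of the suspected interaction (not the actual Lightning source), is that the step-logging decision treats a pending stop request as "log now", and the stop flag set by EarlyStopping is not cleared when min_epochs forces training to continue:

# Simplified model of the suspected logging decision; the real logic lives in
# Lightning's logger connector and may differ in detail.
def should_update_logs(global_step: int, log_every_n_steps: int, should_stop: bool) -> bool:
    # Log every `log_every_n_steps` steps, and always log when a stop is pending.
    # If `should_stop` stays True while min_epochs keeps training alive, this
    # returns True for every remaining batch.
    return (global_step + 1) % log_every_n_steps == 0 or should_stop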

Environment

Current environment
  • CUDA:
    - GPU: None
    - available: False
    - version: None
  • Lightning:
    - lightning: 2.2.5
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.2
    - torch: 2.3.0
    - torchmetrics: 1.4.0.post0
  • Packages:
    - appdirs: 1.4.4
    - appnope: 0.1.4
    - asttokens: 2.4.1
    - brotli: 1.1.0
    - certifi: 2024.6.2
    - chardet: 5.2.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - colorama: 0.4.6
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - docker-pycreds: 0.4.0
    - exceptiongroup: 1.2.0
    - executing: 2.0.1
    - filelock: 3.14.0
    - fonttools: 4.53.0
    - freetype-py: 2.3.0
    - fsspec: 2024.6.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - gmpy2: 2.1.5
    - greenlet: 3.0.3
    - idna: 3.7
    - importlib-metadata: 7.1.0
    - ipykernel: 6.29.3
    - ipython: 8.25.0
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - joblib: 1.4.2
    - jupyter-client: 8.6.2
    - jupyter-core: 5.7.2
    - kiwisolver: 1.4.5
    - lightning: 2.2.5
    - lightning-utilities: 0.11.2
    - markupsafe: 2.1.5
    - matplotlib: 3.8.4
    - matplotlib-inline: 0.1.7
    - mpmath: 1.3.0
    - munkres: 1.1.4
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - numexpr: 2.10.0
    - numpy: 1.26.4
    - packaging: 24.0
    - pandas: 2.2.2
    - parso: 0.8.4
    - pathtools: 0.1.2
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 10.3.0
    - pip: 24.0
    - platformdirs: 4.2.2
    - prompt-toolkit: 3.0.46
    - protobuf: 4.25.3
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - pycairo: 1.26.0
    - pygments: 2.18.0
    - pyparsing: 3.1.2
    - pysocks: 1.7.1
    - python-dateutil: 2.9.0
    - pytorch-lightning: 2.2.2
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - rdkit: 2024.3.3
    - reportlab: 4.1.0
    - requests: 2.32.3
    - rlpycairo: 0.2.0
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - sentry-sdk: 2.4.0
    - setproctitle: 1.3.3
    - setuptools: 70.0.0
    - six: 1.16.0
    - smmap: 5.0.0
    - sqlalchemy: 2.0.30
    - stack-data: 0.6.2
    - sympy: 1.12
    - tables: 3.9.2
    - threadpoolctl: 3.5.0
    - torch: 2.3.0
    - torchmetrics: 1.4.0.post0
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - typing-extensions: 4.12.1
    - tzdata: 2024.1
    - urllib3: 2.2.1
    - wandb: 0.16.5
    - wcwidth: 0.2.13
    - wheel: 0.43.0
    - zipp: 3.17.0
  • System:
    - OS: Darwin
    - architecture: 64bit
    - processor: arm
    - python: 3.11.9
    - release: 23.5.0
    - version: Darwin Kernel Version 23.5.0: Wed May 1 20:19:05 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T8112

More info

The symptoms of this bug are somewhat similar to those of #16821 and #13525, but based on those threads it seems like the causes may be different.
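
A possible workaround, pending a fix, is to delay the early-stopping check until min_epochs has elapsed so the stop flag is never set prematurely. The sketch below is only an assumption-laden example: it overrides the private _run_early_stopping_check method of EarlyStopping, which may change between versions, and the delay_epochs parameter is invented for this illustration.

# Workaround sketch: skip early-stopping checks for the first few epochs.
# Relies on a private EarlyStopping method, so it may break across versions.
from lightning.pytorch.callbacks import EarlyStopping


class DelayedEarlyStopping(EarlyStopping):
    def __init__(self, *args, delay_epochs: int = 2, **kwargs):
        super().__init__(*args, **kwargs)
        self.delay_epochs = delay_epochs  # invented parameter, not a Lightning option

    def _run_early_stopping_check(self, trainer):
        # Skip the check entirely before `delay_epochs`, so trainer.should_stop
        # is never set while min_epochs would override it anyway.
        if trainer.current_epoch < self.delay_epochs:
            return
        super()._run_early_stopping_check(trainer)

With min_epochs=2 in the Trainer, DelayedEarlyStopping(monitor="train_loss", patience=2, delay_epochs=2) should avoid the premature stop request and keep the logging cadence at log_every_n_steps.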