Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

ModelCheckpoint: Using save_top_k, only the first k models are stored, not the best k models

gboeer opened this issue

Bug description

From the documentation, I got the impression that the save_top_k argument of the ModelCheckpoint callback would cause the best k models to be stored, according to the monitored value. However, in my experiments only the first 3 models (from epochs 0, 1, and 2) are stored and nothing afterward. I made sure that the monitored value is indeed higher for later epochs, which I can see clearly in the logged metrics.csv.

So either this is a bug or I simply misunderstood the meaning of this parameter.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import torch
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(save_top_k=3, monitor="val_accuracy")
trainer = L.Trainer(accelerator='gpu', devices=[0], log_every_n_steps=10, callbacks=[checkpoint_callback])

# validation_step of my LightningModule
def validation_step(self, batch, batch_idx):
    inputs, labels = batch
    outputs = self.model(inputs)
    loss = self.val_criterion(outputs, labels)
    # predicted class = index of the largest output along dim 1
    _, predictions = torch.max(outputs, 1)
    # fraction of correct predictions in this batch
    val_accuracy = torch.sum(predictions == labels).double() / labels.size(0)
    self.log('val_loss', loss)
    self.log('val_accuracy', val_accuracy)
    return loss

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - NVIDIA GeForce RTX 2080 Ti
    - available: True
    - version: 11.7
  • Lightning:
    - lightning: 2.2.4
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.4
    - torch: 2.0.1
    - torchmetrics: 1.3.2
    - torchvision: 0.15.2
  • Packages:
    - aiofiles: 23.2.1
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - alembic: 1.13.1
    - aniso8601: 9.0.1
    - annotated-types: 0.6.0
    - anyio: 4.3.0
    - argcomplete: 3.3.0
    - asttokens: 2.4.1
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - azure-core: 1.30.1
    - azure-identity: 1.16.0
    - azure-storage-blob: 12.19.1
    - backcall: 0.2.0
    - backoff: 2.2.1
    - backports.zoneinfo: 0.2.1
    - bcrypt: 4.1.2
    - beautifulsoup4: 4.12.3
    - blinker: 1.8.1
    - boto3: 1.34.92
    - botocore: 1.34.92
    - bracex: 2.4
    - brotli: 1.1.0
    - cachetools: 5.3.3
    - certifi: 2024.2.2
    - cffi: 1.16.0
    - cfgv: 3.4.0
    - chardet: 5.2.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - cloudpickle: 3.0.0
    - cmake: 3.29.2
    - colorama: 0.4.6
    - contourpy: 1.1.1
    - coverage: 7.5.0
    - cryptography: 42.0.5
    - cycler: 0.12.1
    - dacite: 1.7.0
    - dbus-python: 1.2.16
    - decorator: 5.1.1
    - deprecated: 1.2.14
    - dill: 0.3.8
    - distlib: 0.3.8
    - dnspython: 2.6.1
    - docker: 7.0.0
    - entrypoints: 0.4
    - exceptiongroup: 1.2.1
    - executing: 2.0.1
    - fiftyone: 0.15.7
    - fiftyone-brain: 0.16.1
    - fiftyone-db: 0.4.0
    - filelock: 3.13.4
    - flask: 3.0.3
    - fonttools: 4.51.0
    - frozenlist: 1.4.1
    - fsspec: 2024.3.1
    - ftfy: 6.2.0
    - future: 1.0.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - glob2: 0.7
    - google-api-core: 2.18.0
    - google-api-python-client: 2.127.0
    - google-auth: 2.29.0
    - google-auth-httplib2: 0.2.0
    - google-cloud-core: 2.4.1
    - google-cloud-storage: 2.16.0
    - google-crc32c: 1.5.0
    - google-resumable-media: 2.7.0
    - googleapis-common-protos: 1.63.0
    - graphene: 3.3
    - graphql-core: 3.2.3
    - graphql-relay: 3.2.0
    - greenlet: 3.0.3
    - gunicorn: 21.2.0
    - h11: 0.14.0
    - h2: 4.1.0
    - hpack: 4.0.0
    - httpcore: 1.0.5
    - httplib2: 0.22.0
    - httpx: 0.27.0
    - humanize: 4.9.0
    - hypercorn: 0.16.0
    - hyperframe: 6.0.1
    - identify: 2.5.36
    - idna: 3.7
    - imageio: 2.34.1
    - imgaug: 0.4.0
    - importlib-metadata: 7.1.0
    - importlib-resources: 6.4.0
    - inflate64: 1.0.0
    - iniconfig: 2.0.0
    - ipython: 8.12.3
    - isodate: 0.6.1
    - itsdangerous: 2.2.0
    - jedi: 0.19.1
    - jinja2: 3.1.3
    - jmespath: 1.0.1
    - joblib: 1.4.0
    - jsonlines: 4.0.0
    - kaleido: 0.2.1
    - kiwisolver: 1.4.5
    - lazy-loader: 0.4
    - lightning: 2.2.4
    - lightning-utilities: 0.11.2
    - lit: 18.1.4
    - mako: 1.3.3
    - markdown: 3.6
    - markupsafe: 2.1.5
    - matplotlib: 3.7.5
    - matplotlib-inline: 0.1.7
    - mlflow: 2.12.1
    - mongoengine: 0.24.2
    - motor: 3.4.0
    - mpmath: 1.3.0
    - msal: 1.28.0
    - msal-extensions: 1.1.0
    - multidict: 6.0.5
    - multivolumefile: 0.2.3
    - mypy: 1.10.0
    - mypy-extensions: 1.0.0
    - networkx: 3.1
    - nodeenv: 1.8.0
    - numpy: 1.24.4
    - nvidia-cublas-cu11: 11.10.3.66
    - nvidia-cuda-cupti-cu11: 11.7.101
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cudnn-cu11: 8.5.0.96
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.2.10.91
    - nvidia-cusolver-cu11: 11.4.0.1
    - nvidia-cusparse-cu11: 11.7.4.91
    - nvidia-nccl-cu11: 2.14.3
    - nvidia-nvtx-cu11: 11.7.91
    - opencv-python: 4.9.0.80
    - opencv-python-headless: 4.9.0.80
    - packaging: 24.0
    - pandas: 2.0.3
    - paramiko: 3.4.0
    - parso: 0.8.4
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 10.3.0
    - pip: 24.0
    - pkg-resources: 0.0.0
    - platformdirs: 4.2.1
    - plotly: 5.21.0
    - pluggy: 1.5.0
    - portalocker: 2.8.2
    - pprintpp: 0.4.0
    - pre-commit: 3.5.0
    - priority: 2.0.0
    - prompt-toolkit: 3.0.43
    - proto-plus: 1.23.0
    - protobuf: 4.25.3
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py7zr: 0.21.0
    - pyarrow: 15.0.2
    - pyasn1: 0.6.0
    - pyasn1-modules: 0.4.0
    - pybcj: 1.0.2
    - pycparser: 2.22
    - pycryptodomex: 3.20.0
    - pydantic: 2.7.1
    - pydantic-core: 2.18.2
    - pygments: 2.17.2
    - pygobject: 3.36.0
    - pyjwt: 2.8.0
    - pymongo: 4.7.0
    - pynacl: 1.5.0
    - pyparsing: 3.1.2
    - pyppmd: 1.1.0
    - pyproject-api: 1.6.1
    - pysftp: 0.2.9
    - pytest: 8.1.1
    - pytest-cov: 5.0.0
    - pytest-mock: 3.14.0
    - python-dateutil: 2.9.0.post0
    - pytorch-lightning: 2.2.4
    - pytz: 2024.1
    - pywavelets: 1.4.1
    - pyyaml: 6.0.1
    - pyzstd: 0.15.10
    - querystring-parser: 1.2.4
    - rarfile: 4.2
    - regex: 2024.4.16
    - requests: 2.31.0
    - retrying: 1.3.4
    - rsa: 4.9
    - ruff: 0.4.2
    - s3transfer: 0.10.1
    - schedule: 1.2.1
    - scikit-image: 0.21.0
    - scikit-learn: 1.3.2
    - scipy: 1.10.1
    - setuptools: 44.0.0
    - shapely: 2.0.4
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - sortedcontainers: 2.4.0
    - soupsieve: 2.5
    - sqlalchemy: 2.0.29
    - sqlparse: 0.5.0
    - sse-starlette: 0.10.3
    - sseclient-py: 1.8.0
    - stack-data: 0.6.3
    - starlette: 0.37.2
    - strawberry-graphql: 0.138.1
    - sympy: 1.12
    - tabulate: 0.9.0
    - taskgroup: 0.0.0a4
    - tenacity: 8.2.3
    - texttable: 1.7.0
    - threadpoolctl: 3.4.0
    - tifffile: 2023.7.10
    - tomli: 2.0.1
    - torch: 2.0.1
    - torchmetrics: 1.3.2
    - torchvision: 0.15.2
    - tox: 4.14.2
    - tqdm: 4.66.2
    - traitlets: 5.14.3
    - triton: 2.0.0
    - typing-extensions: 4.11.0
    - tzdata: 2024.1
    - tzlocal: 5.2
    - universal-analytics-python3: 1.1.1
    - uritemplate: 4.1.1
    - urllib3: 1.26.18
    - virtualenv: 20.26.0
    - voxel51-eta: 0.12.6
    - wcmatch: 8.5.1
    - wcwidth: 0.2.13
    - websocket-client: 1.4.2
    - werkzeug: 3.0.2
    - wheel: 0.43.0
    - wrapt: 1.16.0
    - wsproto: 1.2.0
    - xmltodict: 0.13.0
    - yacs: 0.1.8
    - yarl: 1.9.4
    - zipp: 3.18.1
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.8.10
    - release: 4.15.0-213-generic
    - version: #224-Ubuntu SMP Mon Jun 19 13:30:12 UTC 2023

More info

No response

It looks like you need to set ModelCheckpoint(save_top_k=3, monitor="val_accuracy", mode="max") so that the checkpoints with the highest accuracy are kept. The default is mode="min", which keeps the checkpoints with the lowest monitored value, as you would want when monitoring a loss.
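
To illustrate why mode="min" ends up keeping only epochs 0, 1, 2 when the monitored accuracy keeps rising, here is a minimal pure-Python sketch of top-k selection. It is a simplified model and not Lightning's actual implementation; the function name simulate_top_k and the accuracy values are made up for illustration.

```python
def simulate_top_k(scores, k, mode):
    """Return the epochs whose checkpoints survive after all scores are seen.

    Simplified model of top-k checkpointing (hypothetical helper, not
    Lightning code). `scores` holds one monitored value per epoch.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    kept = {}  # epoch -> monitored score of the checkpoint kept for it
    for epoch, score in enumerate(scores):
        if len(kept) < k:
            # Fewer than k checkpoints so far: always keep the new one.
            kept[epoch] = score
            continue
        # The "worst" kept checkpoint has the highest score under "min"
        # and the lowest score under "max".
        pick_worst = max if mode == "min" else min
        worst_epoch = pick_worst(kept, key=kept.get)
        worst_score = kept[worst_epoch]
        improved = score < worst_score if mode == "min" else score > worst_score
        if improved:
            # Replace the worst kept checkpoint with the new one.
            del kept[worst_epoch]
            kept[epoch] = score
    return sorted(kept)


# Made-up validation accuracies that improve every epoch.
val_accuracy = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75]

print(simulate_top_k(val_accuracy, k=3, mode="min"))  # [0, 1, 2]
print(simulate_top_k(val_accuracy, k=3, mode="max"))  # [3, 4, 5]
```

With mode="min", a rising accuracy never counts as an improvement, so the first three checkpoints are never replaced. With mode="max", each better epoch displaces the worst checkpoint kept so far.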

Ohh, you're absolutely right. Guess the default is meant for using the loss. Stupid me.