ModelCheckpoint: Using save_top_k, only the first k models are stored, not the best k models
gboeer opened this issue
Bug description
From the documentation, I got the impression that the save_top_k argument of the ModelCheckpoint callback would store the best k models according to the monitored value. However, in my experiments only the first k models are stored (with save_top_k=3, those from epochs 0, 1 and 2) and nothing afterward. I made sure that the monitored value is indeed higher for later epochs, which I can see clearly in the logged metrics.csv.
So either this is a bug, or I simply misunderstood the meaning of this parameter.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
```python
import torch
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(save_top_k=3, monitor="val_accuracy")
trainer = L.Trainer(accelerator="gpu", devices=[0], log_every_n_steps=10, callbacks=[checkpoint_callback])

# val function of my LightningModule
def validation_step(self, batch, batch_idx):
    inputs, labels = batch
    outputs = self.model(inputs)
    loss = self.val_criterion(outputs, labels)
    # Predicted class = index of the largest output along the class dimension
    _, predictions = torch.max(outputs, 1)
    val_accuracy = torch.sum(predictions == labels.data).double() / labels.size(0)
    self.log("val_loss", loss)
    self.log("val_accuracy", val_accuracy)
    return loss
```
Error messages and logs
No response
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- available: True
- version: 11.7
- Lightning:
- lightning: 2.2.4
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.2.4
- torch: 2.0.1
- torchmetrics: 1.3.2
- torchvision: 0.15.2
- Packages:
- aiofiles: 23.2.1
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- alembic: 1.13.1
- aniso8601: 9.0.1
- annotated-types: 0.6.0
- anyio: 4.3.0
- argcomplete: 3.3.0
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 23.2.0
- azure-core: 1.30.1
- azure-identity: 1.16.0
- azure-storage-blob: 12.19.1
- backcall: 0.2.0
- backoff: 2.2.1
- backports.zoneinfo: 0.2.1
- bcrypt: 4.1.2
- beautifulsoup4: 4.12.3
- blinker: 1.8.1
- boto3: 1.34.92
- botocore: 1.34.92
- bracex: 2.4
- brotli: 1.1.0
- cachetools: 5.3.3
- certifi: 2024.2.2
- cffi: 1.16.0
- cfgv: 3.4.0
- chardet: 5.2.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- cloudpickle: 3.0.0
- cmake: 3.29.2
- colorama: 0.4.6
- contourpy: 1.1.1
- coverage: 7.5.0
- cryptography: 42.0.5
- cycler: 0.12.1
- dacite: 1.7.0
- dbus-python: 1.2.16
- decorator: 5.1.1
- deprecated: 1.2.14
- dill: 0.3.8
- distlib: 0.3.8
- dnspython: 2.6.1
- docker: 7.0.0
- entrypoints: 0.4
- exceptiongroup: 1.2.1
- executing: 2.0.1
- fiftyone: 0.15.7
- fiftyone-brain: 0.16.1
- fiftyone-db: 0.4.0
- filelock: 3.13.4
- flask: 3.0.3
- fonttools: 4.51.0
- frozenlist: 1.4.1
- fsspec: 2024.3.1
- ftfy: 6.2.0
- future: 1.0.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- glob2: 0.7
- google-api-core: 2.18.0
- google-api-python-client: 2.127.0
- google-auth: 2.29.0
- google-auth-httplib2: 0.2.0
- google-cloud-core: 2.4.1
- google-cloud-storage: 2.16.0
- google-crc32c: 1.5.0
- google-resumable-media: 2.7.0
- googleapis-common-protos: 1.63.0
- graphene: 3.3
- graphql-core: 3.2.3
- graphql-relay: 3.2.0
- greenlet: 3.0.3
- gunicorn: 21.2.0
- h11: 0.14.0
- h2: 4.1.0
- hpack: 4.0.0
- httpcore: 1.0.5
- httplib2: 0.22.0
- httpx: 0.27.0
- humanize: 4.9.0
- hypercorn: 0.16.0
- hyperframe: 6.0.1
- identify: 2.5.36
- idna: 3.7
- imageio: 2.34.1
- imgaug: 0.4.0
- importlib-metadata: 7.1.0
- importlib-resources: 6.4.0
- inflate64: 1.0.0
- iniconfig: 2.0.0
- ipython: 8.12.3
- isodate: 0.6.1
- itsdangerous: 2.2.0
- jedi: 0.19.1
- jinja2: 3.1.3
- jmespath: 1.0.1
- joblib: 1.4.0
- jsonlines: 4.0.0
- kaleido: 0.2.1
- kiwisolver: 1.4.5
- lazy-loader: 0.4
- lightning: 2.2.4
- lightning-utilities: 0.11.2
- lit: 18.1.4
- mako: 1.3.3
- markdown: 3.6
- markupsafe: 2.1.5
- matplotlib: 3.7.5
- matplotlib-inline: 0.1.7
- mlflow: 2.12.1
- mongoengine: 0.24.2
- motor: 3.4.0
- mpmath: 1.3.0
- msal: 1.28.0
- msal-extensions: 1.1.0
- multidict: 6.0.5
- multivolumefile: 0.2.3
- mypy: 1.10.0
- mypy-extensions: 1.0.0
- networkx: 3.1
- nodeenv: 1.8.0
- numpy: 1.24.4
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.2.10.91
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-nccl-cu11: 2.14.3
- nvidia-nvtx-cu11: 11.7.91
- opencv-python: 4.9.0.80
- opencv-python-headless: 4.9.0.80
- packaging: 24.0
- pandas: 2.0.3
- paramiko: 3.4.0
- parso: 0.8.4
- pexpect: 4.9.0
- pickleshare: 0.7.5
- pillow: 10.3.0
- pip: 24.0
- pkg-resources: 0.0.0
- platformdirs: 4.2.1
- plotly: 5.21.0
- pluggy: 1.5.0
- portalocker: 2.8.2
- pprintpp: 0.4.0
- pre-commit: 3.5.0
- priority: 2.0.0
- prompt-toolkit: 3.0.43
- proto-plus: 1.23.0
- protobuf: 4.25.3
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py7zr: 0.21.0
- pyarrow: 15.0.2
- pyasn1: 0.6.0
- pyasn1-modules: 0.4.0
- pybcj: 1.0.2
- pycparser: 2.22
- pycryptodomex: 3.20.0
- pydantic: 2.7.1
- pydantic-core: 2.18.2
- pygments: 2.17.2
- pygobject: 3.36.0
- pyjwt: 2.8.0
- pymongo: 4.7.0
- pynacl: 1.5.0
- pyparsing: 3.1.2
- pyppmd: 1.1.0
- pyproject-api: 1.6.1
- pysftp: 0.2.9
- pytest: 8.1.1
- pytest-cov: 5.0.0
- pytest-mock: 3.14.0
- python-dateutil: 2.9.0.post0
- pytorch-lightning: 2.2.4
- pytz: 2024.1
- pywavelets: 1.4.1
- pyyaml: 6.0.1
- pyzstd: 0.15.10
- querystring-parser: 1.2.4
- rarfile: 4.2
- regex: 2024.4.16
- requests: 2.31.0
- retrying: 1.3.4
- rsa: 4.9
- ruff: 0.4.2
- s3transfer: 0.10.1
- schedule: 1.2.1
- scikit-image: 0.21.0
- scikit-learn: 1.3.2
- scipy: 1.10.1
- setuptools: 44.0.0
- shapely: 2.0.4
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.1
- sortedcontainers: 2.4.0
- soupsieve: 2.5
- sqlalchemy: 2.0.29
- sqlparse: 0.5.0
- sse-starlette: 0.10.3
- sseclient-py: 1.8.0
- stack-data: 0.6.3
- starlette: 0.37.2
- strawberry-graphql: 0.138.1
- sympy: 1.12
- tabulate: 0.9.0
- taskgroup: 0.0.0a4
- tenacity: 8.2.3
- texttable: 1.7.0
- threadpoolctl: 3.4.0
- tifffile: 2023.7.10
- tomli: 2.0.1
- torch: 2.0.1
- torchmetrics: 1.3.2
- torchvision: 0.15.2
- tox: 4.14.2
- tqdm: 4.66.2
- traitlets: 5.14.3
- triton: 2.0.0
- typing-extensions: 4.11.0
- tzdata: 2024.1
- tzlocal: 5.2
- universal-analytics-python3: 1.1.1
- uritemplate: 4.1.1
- urllib3: 1.26.18
- virtualenv: 20.26.0
- voxel51-eta: 0.12.6
- wcmatch: 8.5.1
- wcwidth: 0.2.13
- websocket-client: 1.4.2
- werkzeug: 3.0.2
- wheel: 0.43.0
- wrapt: 1.16.0
- wsproto: 1.2.0
- xmltodict: 0.13.0
- yacs: 0.1.8
- yarl: 1.9.4
- zipp: 3.18.1
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.10
- release: 4.15.0-213-generic
- version: #224-Ubuntu SMP Mon Jun 19 13:30:12 UTC 2023
More info
No response
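Seems like maybe you need to set ModelCheckpoint(save_top_k=3, monitor="val_accuracy", mode="max") so that it saves the checkpoints with the highest accuracy. The default value is mode="min", which is meant for monitoring a loss: it keeps the k checkpoints with the lowest monitored value. If your val_accuracy keeps rising during training, the three lowest values come from epochs 0, 1 and 2, so those checkpoints are never replaced.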
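The effect of the mode switch described above can be illustrated with a small pure-Python simulation. `kept_epochs` is a hypothetical helper: a simplified model of top-k retention, not Lightning's actual ModelCheckpoint code (which also manages checkpoint files on disk). The accuracy values are illustrative only, not taken from the reporter's metrics.csv.

```python
def kept_epochs(scores, k, mode="min"):
    """Return the sorted epoch indices a simple top-k policy would retain.

    Hypothetical helper, a simplified model of top-k checkpoint retention;
    not Lightning's ModelCheckpoint implementation.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    retained = {}  # epoch index -> monitored score of a retained checkpoint
    for epoch, score in enumerate(scores):
        if len(retained) < k:
            # Fewer than k checkpoints so far: always keep the new one.
            retained[epoch] = score
            continue
        # Find the currently "worst" retained checkpoint under this mode.
        if mode == "min":
            worst = max(retained, key=retained.get)
            improves = score < retained[worst]
        else:
            worst = min(retained, key=retained.get)
            improves = score > retained[worst]
        # Replace the worst checkpoint only if the new score is strictly better.
        if improves:
            del retained[worst]
            retained[epoch] = score
    return sorted(retained)


val_accuracy = [0.50, 0.60, 0.70, 0.75, 0.80, 0.85]  # illustrative, rising accuracy
print(kept_epochs(val_accuracy, k=3, mode="min"))  # [0, 1, 2]
print(kept_epochs(val_accuracy, k=3, mode="max"))  # [3, 4, 5]
```

With mode="min", every later epoch has a higher accuracy than the worst retained one, so nothing is ever replaced and only the first three epochs survive, which matches the behaviour reported in the issue.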
Ohh, you're absolutely right. I guess the default is meant for monitoring the loss. Stupid me.