jeremyjordan / flower-classifier

A simple image classifier for flowers.

Home Page: https://share.streamlit.io/jeremyjordan/flower-classifier/app.py

[BUG] model checkpointing isn't saving weights

jeremyjordan opened this issue

Describe the bug
We get the following error at the end of the first training epoch.

Traceback (most recent call last):
  File "flower_classifier/models/train.py", line 67, in <module>
    trainer.fit(model, datamodule=data_module)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/gpu_backend.py", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 1239, in run_pretrain_routine
    self.train()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py", line 394, in train
    self.run_training_epoch()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py", line 516, in run_training_epoch
    self.run_evaluation(test_mode=False)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/evaluation_loop.py", line 603, in run_evaluation
    self.on_validation_end()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/callback_hook.py", line 176, in on_validation_end
    callback.on_validation_end(self, self.get_model())
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py", line 27, in wrapped_fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 380, in on_validation_end
    self._do_check_save(filepath, current, epoch, trainer, pl_module)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 421, in _do_check_save
    self._save_model(filepath, trainer, pl_module)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 212, in _save_model
    raise ValueError(".save_function() not set")
ValueError: .save_function() not set

This happens because PyTorch Lightning does some extra processing for the checkpoint_callback argument that it does not do for callbacks passed through the generic callbacks list:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ff0064f9563bcbbd2e3606ffb99ce8ba85a2791b/pytorch_lightning/trainer/connectors/callback_connector.py#L67
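
One stopgap that follows from this (my own assumption, not an official fix) is to attach the trainer's checkpoint-saving routine to the callback by hand, mirroring what the connector does for the dedicated argument. The save_function attribute name comes from the traceback above, and Trainer.save_checkpoint is a public method:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(save_top_k=3)
trainer = Trainer(gpus=1, callbacks=[checkpoint_callback], overfit_pct=0.01)
# The generic callbacks path never sets this attribute, hence the
# ".save_function() not set" error; wiring it manually lets the callback
# actually write checkpoints.
checkpoint_callback.save_function = trainer.save_checkpoint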

To Reproduce
Steps to reproduce the behavior:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Any LightningModule reproduces this; LightningModel stands in for the
# project's model class.
model = LightningModel()

# Passing the checkpoint callback through the generic `callbacks` list
# (rather than the `checkpoint_callback` argument) is what triggers the error.
checkpoint_callback = ModelCheckpoint(save_top_k=3)
trainer = Trainer(
    gpus=1, callbacks=[checkpoint_callback], overfit_pct=0.01
)
trainer.fit(model)
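
For comparison, a sketch of the interim workaround: pass the same callback through the dedicated checkpoint_callback argument, which is the path where the connector linked above sets save_function. This assumes the PyTorch Lightning version from the traceback, where checkpoint_callback still accepts a ModelCheckpoint instance.

model = LightningModel()
checkpoint_callback = ModelCheckpoint(save_top_k=3)
# Same run configuration, but the callback goes through the dedicated argument,
# so the trainer wires up its save function and checkpoints are written.
trainer = Trainer(
    gpus=1, checkpoint_callback=checkpoint_callback, overfit_pct=0.01
)
trainer.fit(model)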