[BUG] model checkpointing isn't saving weights
Opened by Jeremy Jordan (jeremyjordan)
Describe the bug
Training fails with the following error at the end of the first epoch, when the checkpoint callback tries to save:
Traceback (most recent call last):
  File "flower_classifier/models/train.py", line 67, in <module>
    trainer.fit(model, datamodule=data_module)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 1073, in fit
    results = self.accelerator_backend.train(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/gpu_backend.py", line 51, in train
    results = self.trainer.run_pretrain_routine(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 1239, in run_pretrain_routine
    self.train()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py", line 394, in train
    self.run_training_epoch()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py", line 516, in run_training_epoch
    self.run_evaluation(test_mode=False)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/evaluation_loop.py", line 603, in run_evaluation
    self.on_validation_end()
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/callback_hook.py", line 176, in on_validation_end
    callback.on_validation_end(self, self.get_model())
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py", line 27, in wrapped_fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 380, in on_validation_end
    self._do_check_save(filepath, current, epoch, trainer, pl_module)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 421, in _do_check_save
    self._save_model(filepath, trainer, pl_module)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 212, in _save_model
    raise ValueError(".save_function() not set")
ValueError: .save_function() not set
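This happens because PyTorch Lightning does some extra processing for a ModelCheckpoint passed via the
checkpoint_callback arg that it doesn't do for one passed in the normal callbacks list. Presumably that processing is what sets save_function, which the error reports as missing.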
https://github.com/PyTorchLightning/pytorch-lightning/blob/ff0064f9563bcbbd2e3606ffb99ce8ba85a2791b/pytorch_lightning/trainer/connectors/callback_connector.py#L67
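To make the suspected mechanism concrete, here is a minimal sketch in plain Python. It does not use PyTorch Lightning at all: ToyTrainer and ToyCheckpoint are hypothetical stand-ins invented for illustration, and the assumption that the extra processing amounts to assigning save_function is inferred from the error message, not confirmed from the library source.

```python
class ToyCheckpoint:
    """Hypothetical stand-in for ModelCheckpoint: it can only save once save_function is set."""

    def __init__(self):
        self.save_function = None

    def on_validation_end(self, filepath):
        if self.save_function is None:
            # Same failure mode as in the traceback above.
            raise ValueError(".save_function() not set")
        self.save_function(filepath)


class ToyTrainer:
    """Hypothetical stand-in for Trainer: only the dedicated argument gets wired up."""

    def __init__(self, checkpoint_callback=None, callbacks=None):
        self.saved = []
        self.callbacks = list(callbacks or [])
        if checkpoint_callback is not None:
            # Assumed "extra processing": give the callback a way to write checkpoints.
            checkpoint_callback.save_function = self.save_checkpoint
            self.callbacks.append(checkpoint_callback)

    def save_checkpoint(self, filepath):
        self.saved.append(filepath)

    def run_validation_end(self):
        for callback in self.callbacks:
            callback.on_validation_end("epoch=0.ckpt")


# Passed through the dedicated argument: the callback is wired and saving works.
wired = ToyTrainer(checkpoint_callback=ToyCheckpoint())
wired.run_validation_end()

# Passed through the generic list: nothing sets save_function, so saving fails.
unwired = ToyTrainer(callbacks=[ToyCheckpoint()])
try:
    unwired.run_validation_end()
    error_message = None
except ValueError as err:
    error_message = str(err)
```

If the real Trainer behaves like this sketch, a plausible fix would be to apply the same wiring to any ModelCheckpoint found in the callbacks list, and a possible workaround would be to pass the callback through the checkpoint_callback arg in the meantime.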
To Reproduce
Steps to reproduce the behavior:
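from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# LightningModel is the project's LightningModule subclass (definition not shown here).
model = LightningModel()
checkpoint_callback = ModelCheckpoint(save_top_k=3)
# The checkpoint is passed via `callbacks`, not `checkpoint_callback`, which triggers the error.
trainer = Trainer(
    gpus=1, callbacks=[checkpoint_callback], overfit_pct=0.01
)
trainer.fit(model)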