Show how to over-fit batches for real
svnv-svsv-jm opened this issue · comments
Description & Motivation
Ok, there is an `overfit_batches` flag for the Trainer. But how can I programmatically check that the loss actually goes down?
I'd like to write a test, checking that the loss at the end is lower than at the start.
Pitch
import lightning.pytorch as pl
trainer = pl.Trainer(max_steps=64, overfit_batches=1)
(loss_start, loss_end) = trainer.fit(model, datamodule)
assert loss_end < loss_start
Alternatives
Show how to get those values from the Trainer.
Additional context
No response
cc @Borda
I think you can use the Trainer property `logged_metrics`, which contains the logged train loss, for example:
# Retrieve logged losses
train_losses = trainer.logged_metrics['train_loss']
# Check that the final loss is lower than the initial loss
initial_loss = train_losses[0]
final_loss = train_losses[-1]
assert final_loss < initial_loss, "train loss did not decrease"
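Note that `logged_metrics` holds only the most recently logged values, so if you need the full history of the loss you can record it yourself with a callback. Below is a minimal sketch: it assumes Lightning's `on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)` hook signature and that `training_step` returns a dict with a `"loss"` key; in real code the class would subclass `lightning.pytorch.Callback`.

```python
class LossHistory:
    """Sketch of a callback that records the train loss after every batch.

    In a real project, subclass lightning.pytorch.Callback; this plain class
    only illustrates the hook logic so it runs without Lightning installed.
    """

    def __init__(self):
        self.losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `outputs` is whatever training_step returned, e.g. {"loss": tensor}.
        # float() detaches a 0-d tensor to a plain Python number.
        self.losses.append(float(outputs["loss"]))


# Hypothetical usage with the Trainer from the issue:
#   history = LossHistory()
#   trainer = pl.Trainer(max_steps=64, overfit_batches=1, callbacks=[history])
#   trainer.fit(model, datamodule)
#   assert history.losses[-1] < history.losses[0]
```

This sidesteps the single-value limitation of `logged_metrics`: the callback owns the list, so the start and end losses are available after `fit()` returns.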