Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai


Show how to over-fit batches for real

svnv-svsv-jm opened this issue · comments

Description & Motivation

Ok, there is a flag for the Trainer. But how can I programmatically check that the loss goes down?

I'd like to write a test, checking that the loss at the end is lower than at the start.

Pitch

import lightning.pytorch as pl

trainer = pl.Trainer(max_steps=64, overfit_batches=1)
(loss_start, loss_end) = trainer.fit(model, datamodule)

assert loss_end < loss_start

Alternatives

Show how to get those values from the Trainer.

Additional context

No response

cc @Borda

I think you can use the Trainer property 'logged_metrics', which contains the train loss.

for example:

# Retrieve logged losses
train_losses = trainer.logged_metrics["train_loss"]

# Check if the final loss is lower than the initial loss
initial_loss = train_losses[0]
final_loss = train_losses[-1]

assert final_loss < initial_loss, "loss did not decrease"