Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai


Differentiate testing multiple sets/models when logging

leleogere opened this issue

Description & Motivation

In my setup, I need to evaluate my trained model twice, on two different test sets, at the end of training:

trainer.test(model, dataloaders=test_dataloader1)
trainer.test(model, dataloaders=test_dataloader2)

However, both scores are logged under the same key (I'm using the wandb logger), so they are merged into a single metric. I can always retrieve the two values separately through the wandb API, but in the wandb UI it's not easy (if possible at all) to view and compare them.
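In the meantime, a minimal workaround sketch, assuming a hypothetical user-defined test_prefix attribute (not a Lightning API) that is set on the module before each call and read inside test_step:

# Workaround sketch: test_prefix is a hypothetical user-defined attribute
model.test_prefix = "test1"
trainer.test(model, dataloaders=test_dataloader1)
model.test_prefix = "test2"
trainer.test(model, dataloaders=test_dataloader2)

# Inside the LightningModule
def test_step(self, batch, batch_idx):
    acc = self.accuracy(self.forward(batch["x"]), batch["y"])
    self.log(f"{self.test_prefix}/acc", acc)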

This is also a problem when trying to evaluate two different checkpoints:

trainer.test(model, dataloaders=test_dataloader, ckpt_path="last")
trainer.test(model, dataloaders=test_dataloader, ckpt_path="best")
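The same test_prefix trick sketched above would work here too (e.g. setting it to "last" and "best" before the respective calls), but it has to be reset manually before every trainer.test invocation, which is easy to forget.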

Pitch

Ideally, it would be handy for Trainer.test (and perhaps fit, validate, and predict as well) to accept extra keyword arguments that would be passed directly to LightningModule.test_step and LightningModule.on_test_epoch_end.

This would let the user manage the logging process based on their own arguments:

# Training script
trainer.test(model, dataloaders=test_dataloader1, name="test1")
trainer.test(model, dataloaders=test_dataloader2, name="test2")

# LightningModule
def test_step(self, batch, batch_idx, name="test"):
    y_pred = self.forward(batch["x"])
    y_true = batch["y"]
    acc = self.accuracy(y_pred, y_true)  # torchmetrics convention: (preds, target)
    self.log(f"{name}/acc", acc)  # use self.log rather than self.logger.log

This would result in the scores being logged to test1/acc and test2/acc, making it easy to differentiate them in the wandb UI and in the logs.
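Under the proposal, on_test_epoch_end could receive the same keyword (a hypothetical sketch, since the feature does not exist yet), e.g. for an epoch-level torchmetrics metric:

def on_test_epoch_end(self, name="test"):
    self.log(f"{name}/acc_epoch", self.accuracy.compute())
    self.accuracy.reset()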

Alternatives

For the case of multiple test sets, one could merge them and pass them as a single dataloader. However, this prevents comparing performance on each individual dataset.
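Another option that already works for the multiple-test-sets case: Trainer.test accepts a list of dataloaders, and metrics logged via self.log are automatically suffixed with the dataloader index, so the sets remain distinguishable:

# Pass both sets in a single call; Lightning appends "/dataloader_idx_{i}"
# to each metric logged via self.log when several dataloaders are given
trainer.test(model, dataloaders=[test_dataloader1, test_dataloader2])

# LightningModule
def test_step(self, batch, batch_idx, dataloader_idx=0):
    acc = self.accuracy(self.forward(batch["x"]), batch["y"])
    self.log("acc", acc)  # logged as "acc/dataloader_idx_0" and "acc/dataloader_idx_1"

The index-based names are less readable than explicit ones like test1/acc, though, and this does not help with the two-checkpoints case.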

Additional context

No response

cc @Borda