graviraja / MLOps-Basics

Different module metrics for train/val

rohitgr7 opened this issue

Module metrics store internal state that accumulates across every call on different batches. Using the same instance for both train and val can therefore give incorrect results when the metric is computed over an epoch (on_epoch=True) in the step hooks, because both loops update the same state. I'd suggest creating separate instances for each stage (train and val).

ref: https://torchmetrics.readthedocs.io/en/latest/pages/quickstart.html#module-metrics

self.accuracy_metric = torchmetrics.Accuracy()
self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)
self.precision_macro_metric = torchmetrics.Precision(
    average="macro", num_classes=self.num_classes
)
self.recall_macro_metric = torchmetrics.Recall(
    average="macro", num_classes=self.num_classes
)
self.precision_micro_metric = torchmetrics.Precision(average="micro")
self.recall_micro_metric = torchmetrics.Recall(average="micro")
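
To illustrate why sharing one instance is a problem, here is a minimal pure-Python sketch. `RunningAccuracy` is a hypothetical stand-in that only mimics the update/compute pattern of a stateful module metric; it is not the torchmetrics API, and the batches are made-up toy data.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a stateful metric (not the torchmetrics API)."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        # State accumulates across calls, as it does for a module metric.
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total


# Toy data (assumed for illustration): train is 4/4 correct, val is 2/4 correct.
train_batches = [([1, 0, 1, 1], [1, 0, 1, 1])]
val_batches = [([1, 1, 0, 0], [0, 0, 0, 0])]

# One shared instance: val batches are added on top of the train state.
shared = RunningAccuracy()
for preds, targets in train_batches:
    shared.update(preds, targets)
for preds, targets in val_batches:
    shared.update(preds, targets)
shared_val_acc = shared.compute()  # mixes train and val samples

# Separate instances: each stage keeps its own state.
train_acc = RunningAccuracy()
val_acc = RunningAccuracy()
for preds, targets in train_batches:
    train_acc.update(preds, targets)
for preds, targets in val_batches:
    val_acc.update(preds, targets)

print(shared_val_acc, train_acc.compute(), val_acc.compute())
```

With the shared instance, the value reported for val is really an average over train and val samples together. In the module above, the equivalent change would be to keep one set of metric attributes per stage (for example a train and a val accuracy attribute; those names are only a suggestion).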

@rohitgr7 fixed it. Thanks for pointing it out :)