Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.

Home Page: https://lightning.ai/docs/torchmetrics/

Wrong aggregation of Precision/Recall/F1-Score

silvanobianchi opened this issue · comments

πŸ› Bug

I am training a binary classifier and am monitoring the following metrics: precision, recall, accuracy, F1 score.

During the validation phase, the dataloader first loads N batches in which all samples belong to class 0, and then M batches in which all samples belong to class 1.

The aggregate accuracy calculated at the end of the epoch matches what I get from my own script, while Precision and Recall are wrong.

This behaviour occurs for Precision, Recall, and F1-Score, which end up at roughly half the value they should have: accuracy is around 0.9 while these metrics are around 0.55.
I think the problem is that the first N batches have Precision/Recall around 0 because they contain no positive-class samples, while the other M batches have values around 1 because they contain only positive-class samples, so when these per-batch results are aggregated I end up with a value of around 0.5.

I am monitoring these values with BinaryPrecision, BinaryRecall, BinaryF1Score.

Do you know what the problem could be? I think there is something strange in the aggregation of these metrics.
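
A minimal, made-up sketch of what I suspect is happening (the data here is illustrative only, not my actual validation set): computing precision once over the whole epoch gives the expected value, while averaging per-batch precision values over class-ordered batches collapses towards 0.5:

import torch
from torchmetrics.functional.classification import binary_precision

torch.manual_seed(0)
# a fairly accurate classifier on class-ordered data: 500 negatives, then 500 positives
target = torch.cat([torch.zeros(500, dtype=torch.long), torch.ones(500, dtype=torch.long)])
preds = target.clone()
flip = torch.randperm(1000)[:100]  # flip 10% of the predictions
preds[flip] = 1 - preds[flip]

# precision computed once over the whole epoch: about 0.9
print(binary_precision(preds, target))

# naive mean of per-batch precision values: the all-negative batches contribute ~0
# and the all-positive batches contribute ~1, so the mean lands near 0.5
per_batch = [binary_precision(p, t) for p, t in zip(preds.split(100), target.split(100))]
print(torch.stack(per_batch).mean())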

Hi! Thanks for your contribution, great first issue!

Hi @silvanobianchi, thanks for reporting this issue.
This sounds very strange to me. Regardless of how the data is fed into the metrics, the correct result should still be calculated. Here is an example where all the data is fed in as one batch vs. over multiple batches. Also note that the labels are split in the way you describe:

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryPrecision, BinaryRecall, BinaryAccuracy

# random 0/1 predictions; targets are 500 zeros followed by 500 ones,
# mimicking the class-ordered batches described above
preds = torch.randint(2, (1000,))
target = torch.cat([torch.zeros(500), torch.ones(500)], dim=0)

# all as one batch
metrics = MetricCollection(BinaryAccuracy(), BinaryPrecision(), BinaryRecall())
print(metrics(preds, target))

# as separate batches
metrics = MetricCollection(BinaryAccuracy(), BinaryPrecision(), BinaryRecall())
for p, t in zip(preds.split(10), target.split(10)):
    metrics.update(p, t)
print(metrics.compute())

If you run the above, it will output exactly the same values in both cases.
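
As a side note, and just a guess at what might be happening in your script (not something I can confirm from your report): a pattern that does produce numbers like the ones you describe is averaging the values returned when the metric is called directly on each batch. The direct call (forward) returns the value for that batch only, whereas compute() returns the value accumulated over all batches. A small sketch, reusing the preds and target from above:

# calling the metric object returns the recall for *that batch only*
# (it also updates the internal state); compute() returns the accumulated value
recall = BinaryRecall()
batch_values = [recall(p, t) for p, t in zip(preds.split(10), target.split(10))]
print(torch.stack(batch_values).mean())  # mean of per-batch values, dragged down by the all-negative batches
print(recall.compute())                  # recall accumulated over the whole epoch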