Wrong aggregation of Precision\Recall\F1-Score
silvanobianchi opened this issue
🐛 Bug
I am training a binary classifier and am monitoring the following metrics: precision, recall, accuracy, F1 score.
During the validation phase, the dataloader first yields N batches whose samples all belong to class 0, and then M batches whose samples all belong to class 1.
The aggregate accuracy calculated at the end of the epoch is identical to what I get from my own script, but Precision and Recall are wrong.
Precision, Recall and F1-Score all come out at about half the value they should have: Accuracy is 0.9, while these metrics are around 0.55.
I think the problem is that the first N batches have Precision and Recall around 0, because they contain no positive-class samples, while the other M batches have values around 1, because they contain only positive-class samples. When these per-batch results are aggregated, I get a value of around 0.5.
I am monitoring these values with BinaryPrecision, BinaryRecall, BinaryF1Score.
Do you know what the problem could be? I think there is something strange in the aggregation of these metrics.
Hi! Thanks for your contribution! Great first issue!
Hi @silvanobianchi, thanks for reporting this issue.
This sounds very strange to me. Regardless of how the data is fed into the metrics, the correct result should still be calculated. Here is an example where all the data is fed in as one batch vs. over multiple batches. Note that the labels are split in the way you describe:
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryPrecision, BinaryRecall, BinaryAccuracy

# random predictions; targets are 500 negatives followed by 500 positives
preds = torch.randint(2, (1000,))
target = torch.cat([torch.zeros(500), torch.ones(500)], dim=0)

# all as one batch
metrics = MetricCollection(BinaryAccuracy(), BinaryPrecision(), BinaryRecall())
print(metrics(preds, target))

# as separate batches of 10 samples, each containing a single class
metrics = MetricCollection(BinaryAccuracy(), BinaryPrecision(), BinaryRecall())
for p, t in zip(preds.split(10), target.split(10)):
    metrics.update(p, t)  # accumulate state only; no per-batch value is returned
print(metrics.compute())  # computed once from the accumulated state
If you run the above, it will output exactly the same values in both cases.
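For comparison, here is a minimal sketch of the effect described in the issue: averaging per-batch precision and recall, rather than accumulating counts across batches and computing once at the end. This is plain Python, not torchmetrics code. The helper names (`counts`, `safe_div`) and the data are made up for illustration, and it assumes an undefined 0/0 is reported as 0. Whether this is what actually happens in the reporter's setup is not confirmed in this thread.

```python
# Hypothetical data: 5 batches with only class 0, then 5 batches with only
# class 1. Each batch has 10 samples and one wrong prediction (90% accuracy).
neg_batch = ([0] * 9 + [1], [0] * 10)   # (preds, target), one false positive
pos_batch = ([1] * 9 + [0], [1] * 10)   # (preds, target), one false negative
batches = [neg_batch] * 5 + [pos_batch] * 5


def counts(preds, target):
    """Return (tp, fp, fn) for one batch of binary predictions."""
    tp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, target) if p == 0 and t == 1)
    return tp, fp, fn


def safe_div(num, den):
    """Assumption: an undefined ratio (0/0) is reported as 0."""
    return num / den if den else 0.0


# Variant 1: compute precision/recall per batch, then average the values.
per_batch = [counts(p, t) for p, t in batches]
mean_precision = sum(safe_div(tp, tp + fp) for tp, fp, fn in per_batch) / len(per_batch)
mean_recall = sum(safe_div(tp, tp + fn) for tp, fp, fn in per_batch) / len(per_batch)

# Variant 2: accumulate the counts over all batches, then compute once.
total_tp = sum(tp for tp, fp, fn in per_batch)
total_fp = sum(fp for tp, fp, fn in per_batch)
total_fn = sum(fn for tp, fp, fn in per_batch)
acc_precision = safe_div(total_tp, total_tp + total_fp)
acc_recall = safe_div(total_tp, total_tp + total_fn)

print(f"mean of per-batch values: precision={mean_precision:.2f} recall={mean_recall:.2f}")
print(f"accumulated counts:       precision={acc_precision:.2f} recall={acc_recall:.2f}")
```

With this data, the averaged per-batch values land near 0.5 while the accumulated values match the 0.9 accuracy. Accuracy does not show the same gap, because with equal-sized batches the mean of per-batch accuracies equals the overall accuracy. If the logged numbers look like the first variant, it may be worth checking whether per-batch outputs are being averaged somewhere instead of calling `compute()` on the accumulated state.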