[Bug]: AUROC miscalculation when using small validation batchsize
brein0453 opened this issue · comments
Describe the bug
I'm using anomalib v1.0.1 and found an AUROC miscalculation when the validation batch size is small.
I think the torchmetrics AUROC causes this error:
https://lightning.ai/docs/torchmetrics/stable/classification/auroc.html
Dataset
Other (please specify in the text field below)
Model
Other (please specify in the field below)
Steps to reproduce the behavior
This is a simple example with batch size 1 and scores in [-2, 1.8]:
>>> from anomalib.metrics import AUROC
>>> import torch
>>> auroc = AUROC()
>>> for pred in torch.arange(-2, 2, 0.2):
...     target = (pred > 0).unsqueeze(0)
...     auroc.update(pred.unsqueeze(0), target)
>>> auroc.compute()
tensor(0.8990)
>>> from sklearn.metrics import roc_auc_score
>>> preds = []
>>> targets = []
>>> for pred in torch.arange(-2, 2, 0.2):
...     preds.append(pred)
...     targets.append(pred > 0)
>>> roc_auc_score(targets, preds)
1.0
OS information
.
Expected behavior
.
Screenshots
No response
Pip/GitHub
pip
What version/branch did you use?
1.0.1
Configuration YAML
.
Logs
.
Code of Conduct
- I agree to follow this project's Code of Conduct
This seems quite problematic, but I am not sure what the reason for this behavior would be. Anomalib's AUROC doesn't use the torchmetrics AUROC, but rather a combination of BinaryROC and auc (source).
BinaryROC applies a sigmoid if the input range is not within [0, 1].
This is the update() function of BinaryROC (https://github.com/Lightning-AI/torchmetrics/blob/be3970810a38753bfe53b8b162743de63bd6b179/src/torchmetrics/classification/precision_recall_curve.py):
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states."""
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
    preds, target, _ = _binary_precision_recall_curve_format(preds, target, self.thresholds, self.ignore_index)
    state = _binary_precision_recall_curve_update(preds, target, self.thresholds)
    if isinstance(state, Tensor):
        self.confmat += state
    else:
        self.preds.append(state[0])
        self.target.append(state[1])
and _binary_precision_recall_curve_format is (https://github.com/Lightning-AI/torchmetrics/blob/be3970810a38753bfe53b8b162743de63bd6b179/src/torchmetrics/functional/classification/precision_recall_curve.py):
def _binary_precision_recall_curve_format(
    preds: Tensor,
    target: Tensor,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    preds = preds.flatten()
    target = target.flatten()
    if ignore_index is not None:
        idx = target != ignore_index
        preds = preds[idx]
        target = target[idx]
    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.sigmoid()
    thresholds = _adjust_threshold_arg(thresholds, preds.device)
    return preds, target, thresholds
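The conditional sigmoid above is what breaks small-batch updates: each `update()` call inspects only its own batch. A minimal standalone sketch of that check (this only mimics the torchmetrics logic shown above; it does not call torchmetrics or anomalib):

```python
import torch

# Mimic the range check in _binary_precision_recall_curve_format:
# each update() call inspects only its own batch of preds.
def maybe_sigmoid(preds: torch.Tensor) -> torch.Tensor:
    # sigmoid fires only when THIS batch has a value outside [0, 1]
    if not torch.all((preds >= 0) & (preds <= 1)):
        return preds.sigmoid()
    return preds

scores = torch.arange(-2, 2, 0.2)

# Updating one element at a time: values already in [0, 1] pass through
# untouched, while the rest get squashed -- an inconsistent transform.
one_by_one = torch.stack([maybe_sigmoid(s.unsqueeze(0)).squeeze(0) for s in scores])

# Updating the whole batch at once: -2 lies outside [0, 1], so the
# sigmoid is applied to every element consistently.
all_at_once = maybe_sigmoid(scores)

# The accumulated scores disagree, which changes the ranking and the AUROC.
print(torch.allclose(one_by_one, all_at_once))  # False
```

For example, a raw score of 0.2 stays 0.2 in its own batch, while a raw score of -0.2 becomes sigmoid(-0.2) ≈ 0.45, reversing their order.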
I think this is the reason.
But the sklearn function does that as well, doesn't it?
sklearn.metrics.roc_auc_score does not apply a sigmoid to the preds.
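Since AUROC depends only on the ranking of the scores, applying a sigmoid to all scores at once is harmless; the miscalculation only appears when the transform hits some batches and not others. A quick sklearn sketch illustrating both points:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

scores = np.arange(-2, 2, 0.2)
targets = (scores > 0).astype(int)

# sklearn uses the scores exactly as given; no sigmoid is applied internally.
raw = roc_auc_score(targets, scores)

# AUROC depends only on the ranking, so a monotonic transform such as
# sigmoid, applied to ALL scores, leaves the result unchanged.
squashed = roc_auc_score(targets, 1.0 / (1.0 + np.exp(-scores)))

print(raw, squashed)  # 1.0 1.0
```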
I fixed this by applying a sigmoid to my model's scores. I think it would be better to deal with this issue in torchmetrics. Thank you for your response.
I actually tested your code and, after some debugging, found the reason. The problem is that when you update values one by one, the sigmoid is only applied to batches containing a value < 0 or > 1; single values that already lie in [0, 1] are left untouched, so the scores end up transformed inconsistently (for example, 0.1 stays 0.1, but sigmoid(0.1) ≈ 0.525). This is a bit of a weird case, but I think it's the expected behavior on the torchmetrics side. To get the same result, you can accumulate everything and update once, as in the sklearn case:
from anomalib.metrics import AUROC
import torch

auroc = AUROC()
preds = []
targets = []
for pred in torch.arange(-2, 2, 0.21):
    preds.append(pred)
    targets.append(int(pred > 0))
auroc.update(torch.tensor(preds), torch.tensor(targets))
auroc.compute()
This now gives the same result: the check that all values are >= 0 and <= 1 fails, so the sigmoid is applied to every element at once, producing the desired behavior.
Another correct workaround for small batches is therefore to apply the sigmoid manually before each update call.
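A sketch of that workaround using plain torch only (no anomalib): once the raw scores are squashed into [0, 1] up front, the range check inside torchmetrics can never fire, and the monotonic sigmoid leaves the ranking, and hence the AUROC, unchanged.

```python
import torch

scores = torch.arange(-2, 2, 0.2)

# Squash raw anomaly scores into [0, 1] before any metric update.
calibrated = torch.sigmoid(scores)

# The torchmetrics range check can now never trigger, even for
# single-element batches, so no second sigmoid is ever applied.
for s in calibrated:
    batch = s.unsqueeze(0)
    assert torch.all((batch >= 0) & (batch <= 1))

# sigmoid is strictly monotonic, so the score ranking -- and with it
# the AUROC -- is unchanged by the calibration step.
print(torch.equal(torch.argsort(scores), torch.argsort(calibrated)))  # True
```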