`MetricCollection` did not copy inner state of metric in `ClasswiseWrapper` when computing groups metrics
daniel-code opened this issue
🐛 Bug

`MetricCollection` does not copy the inner state of a metric wrapped in `ClasswiseWrapper` when computing group metrics.
To Reproduce
Steps to reproduce the behavior...
Code sample
```python
import torch
from lightning import seed_everything
from torchmetrics import MetricCollection, ClasswiseWrapper
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

seed_everything(42)

random_pred = torch.rand((10, 3))
pred = torch.softmax(random_pred, dim=-1)
pred_class = torch.argmax(pred, dim=-1)
target = torch.randint(0, 3, size=(10,))

multiclass_acc = MulticlassAccuracy(num_classes=3, average=None)
print("multiclass_acc:", multiclass_acc(pred, target))

multiclass_f1 = MulticlassF1Score(num_classes=3, average=None)
print("multiclass_f1:", multiclass_f1(pred, target))

mc = MetricCollection(
    {
        "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
        "f1": ClasswiseWrapper(MulticlassF1Score(num_classes=3, average=None)),
    },
    compute_groups=[["accuracy", "f1"]],
)
print("MetricCollection.forward:", mc(pred, target))

mc.reset()
mc.update(pred, target)
print("MetricCollection.update&compute:", mc.compute())
```
Output
```
Seed set to 42
site-packages\torchmetrics\utilities\prints.py:43: UserWarning: The ``compute`` method of metric MulticlassF1Score was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028
multiclass_acc: tensor([0.3333, 0.5000, 0.3333])
multiclass_f1: tensor([0.4000, 0.4444, 0.3333])
MetricCollection.forward: {'multiclassaccuracy_0': tensor(0.3333), 'multiclassaccuracy_1': tensor(0.5000), 'multiclassaccuracy_2': tensor(0.3333), 'multiclassf1score_0': tensor(0.4000), 'multiclassf1score_1': tensor(0.4444), 'multiclassf1score_2': tensor(0.3333)}
MetricCollection.update&compute: {'multiclassaccuracy_0': tensor(0.3333), 'multiclassaccuracy_1': tensor(0.5000), 'multiclassaccuracy_2': tensor(0.3333), 'multiclassf1score_0': tensor(0.), 'multiclassf1score_1': tensor(0.), 'multiclassf1score_2': tensor(0.)}
```
Expected behavior
The `multiclassf1score_0`, `multiclassf1score_1`, and `multiclassf1score_2` values returned by `MetricCollection.compute` should match those returned by `MetricCollection.forward` and by each metric computed individually.
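The zeros come from how compute groups share state. A minimal pure-Python sketch (the `Metric` and `Wrapper` classes below are simplified stand-ins, not the real torchmetrics classes) shows why a plain `setattr` on the wrapper never reaches the inner metric's state:

```python
class Metric:
    """Stand-in for a torchmetrics metric: state lives directly on the object."""

    def __init__(self):
        self.tp = 0


class Wrapper:
    """Stand-in for ClasswiseWrapper: the real state lives on self.metric."""

    def __init__(self, metric):
        self.metric = metric


m0 = Metric()
m0.tp = 5  # the compute-group leader has been updated
wrapped = Wrapper(Metric())

# What the state-sharing step effectively does: copy the group leader's
# state onto each member with a plain setattr.
setattr(wrapped, "tp", m0.tp)

# The attribute lands on the wrapper itself, not on the inner metric it
# delegates to, so the wrapped metric still computes from its empty state.
assert wrapped.tp == 5
assert wrapped.metric.tp == 0  # never updated -> all-zero results in compute()
```

This is why `forward` (which updates each wrapper's inner metric directly) is correct while `compute` after the state-sync path returns zeros.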
Solution 1
https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/collections.py#L305
```python
class MetricCollection:
    def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
        """Create reference between metrics in the same compute group.

        Args:
            copy: If `True` the metric state between members will be copied instead
                of just passed by reference.

        """
        if not self._state_is_copy:
            for cg in self._groups.values():
                m0 = getattr(self, cg[0])
                for i in range(1, len(cg)):
                    mi = getattr(self, cg[i])
                    for state in m0._defaults:
                        m0_state = getattr(m0, state)
                        # Determine if we should just set a reference or a full copy
                        if isinstance(mi, ClasswiseWrapper):  # << Added
                            setattr(mi.metric, state, deepcopy(m0_state) if copy else m0_state)  # << Added
                        setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
                    if isinstance(mi, ClasswiseWrapper):  # << Added
                        mi.metric._update_count = deepcopy(m0._update_count) if copy else m0._update_count  # << Added
                        mi.metric._computed = deepcopy(m0._computed) if copy else m0._computed  # << Added
                    mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
                    mi._computed = deepcopy(m0._computed) if copy else m0._computed
        self._state_is_copy = copy
```
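The `copy` flag in the method above decides whether group members hold a reference to the leader's state objects or an independent snapshot. A small sketch of that distinction with plain lists standing in for state tensors (not torchmetrics code):

```python
from copy import deepcopy

# Plain lists stand in for metric state tensors.
m0_state = [1, 2, 3]

shared = m0_state              # copy=False: members hold a reference
snapshot = deepcopy(m0_state)  # copy=True: members hold their own copy

m0_state.append(4)  # the group leader's state is updated

assert shared == [1, 2, 3, 4]  # the reference sees the leader's update
assert snapshot == [1, 2, 3]   # the deep copy is unaffected
```

Whichever branch is taken, the fix must apply it to `mi.metric` as well as `mi`, since the wrapper's results are computed from the inner metric's state.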
Solution 2
https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/wrappers/classwise.py#L27
```python
class ClasswiseWrapper:
    def __getattr__(self, name: str):
        # Return state from self.metric
        if name in ["tp", "fp", "fn", "tn"]:  # << Added
            return getattr(self.metric, name)  # << Added
        return super().__getattr__(name)

    def __setattr__(self, name: str, value: Any) -> None:
        if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn"]:  # << Added
            setattr(self.metric, name, value)  # << Added
        else:  # << Added
            super().__setattr__(name, value)
        if name == "metric":  # << Added
            self._defaults = self.metric._defaults  # << Added
```
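The delegation Solution 2 relies on can be sketched in plain Python (hypothetical `Inner`/`Wrapper` classes hardcoding the same state names as the patch; not the real `ClasswiseWrapper`):

```python
class Inner:
    """Stand-in for the wrapped metric: the real state lives here."""

    def __init__(self):
        self.tp = 0


class Wrapper:
    """Stand-in for ClasswiseWrapper with the patched dunder methods."""

    _STATE = ("tp", "fp", "fn", "tn")

    def __init__(self, metric):
        self.metric = metric  # handled by __setattr__'s else branch

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so state
        # reads fall through to the inner metric.
        if name in self._STATE:
            return getattr(self.metric, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # Forward state writes (e.g. from compute-group state syncing)
        # to the inner metric instead of shadowing them on the wrapper.
        if name in self._STATE:
            setattr(self.metric, name, value)
        else:
            object.__setattr__(self, name, value)


w = Wrapper(Inner())
w.tp = 7  # the write is forwarded to the inner metric
assert w.metric.tp == 7
assert w.tp == 7  # the read is forwarded back
```

With this pattern, the group-sync `setattr(mi, state, ...)` from Solution 1's unpatched loop already reaches the inner metric, so no change to `MetricCollection` is needed; the trade-off is that the wrapper must know (or derive) the inner metric's state names.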
Environment
- TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 1.3.1
- Python & PyTorch Version (e.g., 1.0):
  - Python: 3.8
  - PyTorch: 2.1.1
- Any other relevant information such as OS (e.g., Linux): Windows 11
Additional context
Updated Solution 2, which overrides the `__getattr__` and `__setattr__` methods of `ClasswiseWrapper`.