Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.

`MetricCollection` did not copy inner state of metric in `ClasswiseWrapper` when computing groups metrics

daniel-code opened this issue

πŸ› Bug

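MetricCollection does not copy the inner state of the metric wrapped by a ClasswiseWrapper when the wrappers share a compute group. As a result, in the example below MetricCollection.compute returns 0 for every class of the F1 score.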

To Reproduce

Run the code sample below.

Code sample
import torch
from lightning import seed_everything
from torchmetrics import MetricCollection, ClasswiseWrapper
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

seed_everything(42)

random_pred = torch.rand((10, 3))
pred = torch.softmax(random_pred, dim=-1)
pred_class = torch.argmax(pred, dim=-1)
target = torch.randint(0, 3, size=(10,))

# Reference values: each metric on its own, one value per class (average=None)
multiclass_acc = MulticlassAccuracy(num_classes=3, average=None)
print("multiclass_acc:", multiclass_acc(pred, target))

multiclass_f1 = MulticlassF1Score(num_classes=3, average=None)
print("multiclass_f1:", multiclass_f1(pred, target))

# Both wrapped metrics are placed in one compute group, so their states are shared
mc = MetricCollection(
    {
        "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
        "f1": ClasswiseWrapper(MulticlassF1Score(num_classes=3, average=None)),
    },
    compute_groups=[["accuracy", "f1"]],
)

print("MetricCollection.forward:", mc(pred, target))
mc.update(pred, target)
print("MetricCollection.update&compute:", mc.compute())

Output
Seed set to 42
site-packages\torchmetrics\utilities\ UserWarning: The ``compute`` method of metric MulticlassF1Score was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028
multiclass_acc: tensor([0.3333, 0.5000, 0.3333])
multiclass_f1: tensor([0.4000, 0.4444, 0.3333])
MetricCollection.forward: {'multiclassaccuracy_0': tensor(0.3333), 'multiclassaccuracy_1': tensor(0.5000), 'multiclassaccuracy_2': tensor(0.3333), 'multiclassf1score_0': tensor(0.4000), 'multiclassf1score_1': tensor(0.4444), 'multiclassf1score_2': tensor(0.3333)}
MetricCollection.update&compute: {'multiclassaccuracy_0': tensor(0.3333), 'multiclassaccuracy_1': tensor(0.5000), 'multiclassaccuracy_2': tensor(0.3333), 'multiclassf1score_0': tensor(0.), 'multiclassf1score_1': tensor(0.), 'multiclassf1score_2': tensor(0.)}

Expected behavior

The values multiclassf1score_0, multiclassf1score_1 and multiclassf1score_2 returned by MetricCollection.compute should match those returned by MetricCollection.forward and by the standalone MulticlassF1Score, instead of all being 0.
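
The zeros can be reproduced without torchmetrics. The sketch below is a simplified model and not the library's code. ToyMetric, ToyWrapper and share_states are hypothetical stand-ins. It assumes, consistent with the warning in the output above, that the wrapper registers no states of its own. Under that assumption, the compute-group step has nothing to copy into the F1 metric. That step updates only the first group member and then copies that member's states to the others.

```python
# Hypothetical stand-ins for a metric, a wrapper and compute-group linking;
# these are NOT torchmetrics classes, only a model of the reported behaviour.


class ToyMetric:
    """Holds a single 'tp' counter, like a stat-scores metric's state."""

    def __init__(self):
        self._defaults = {"tp": 0}  # names of the state attributes
        self.tp = 0

    def update(self, n):
        self.tp += n


class ToyWrapper:
    """Wraps a metric but registers no states of its own (assumption)."""

    def __init__(self, metric):
        self.metric = metric
        self._defaults = {}  # the wrapper itself has no states

    def update(self, n):
        self.metric.update(n)


def share_states(m0, mi):
    """Naive compute-group linking: copy every state of m0 onto mi."""
    for state in m0._defaults:
        setattr(mi, state, getattr(m0, state))


acc = ToyWrapper(ToyMetric())
f1 = ToyWrapper(ToyMetric())

# Within a compute group only the first member is updated...
acc.update(5)
# ...and its states are then shared with the other members.
share_states(acc, f1)

print(acc.metric.tp)  # 5
print(f1.metric.tp)   # 0: nothing was shared, because acc._defaults is empty
```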

Solution 1

from copy import deepcopy

from torchmetrics.wrappers import ClasswiseWrapper


# Patched excerpt of MetricCollection: only the modified method is shown
class MetricCollection:
    def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
        """Create reference between metrics in the same compute group.

        Args:
            copy: If `True`, the metric states of the group members will be copied
                instead of just passed by reference.

        """
        if not self._state_is_copy:
            for cg in self._groups.values():
                m0 = getattr(self, cg[0])
                for i in range(1, len(cg)):
                    mi = getattr(self, cg[i])
                    for state in m0._defaults:
                        m0_state = getattr(m0, state)
                        # Determine if we just should set a reference or a full copy
                        if isinstance(mi, ClasswiseWrapper):  # << Added
                            # The wrapped metric, not the wrapper, holds the states
                            setattr(mi.metric, state, deepcopy(m0_state) if copy else m0_state)  # << Added
                        setattr(mi, state, deepcopy(m0_state) if copy else m0_state)

                    if isinstance(mi, ClasswiseWrapper):  # << Added
                        mi.metric._update_count = deepcopy(m0._update_count) if copy else m0._update_count  # << Added
                        mi.metric._computed = deepcopy(m0._computed) if copy else m0._computed  # << Added
                    mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
                    mi._computed = deepcopy(m0._computed) if copy else m0._computed
        self._state_is_copy = copy
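
Solution 1 routes state writes to the wrapped metric. The idea can be sketched in plain Python. ToyMetric, ToyWrapper, unwrap and link_group are hypothetical names, not torchmetrics classes. This sketch unwraps both the source (the first group member) and the destinations. It does so on the assumption that a wrapper without states of its own offers nothing to iterate over when it is the first member of the group.

```python
from copy import deepcopy


class ToyMetric:
    """Hypothetical stand-in for a metric with one state, 'tp'."""

    def __init__(self):
        self._defaults = {"tp": 0}
        self.tp = 0
        self._update_count = 0

    def update(self, n):
        self.tp += n
        self._update_count += 1


class ToyWrapper:
    """Hypothetical stand-in for ClasswiseWrapper: it owns no states."""

    def __init__(self, metric):
        self.metric = metric
        self._defaults = {}


def unwrap(m):
    # Resolve a wrapper to the metric that actually holds the states
    return m.metric if isinstance(m, ToyWrapper) else m


def link_group(members, copy=False):
    """Share the states of members[0] with the remaining members."""
    src = unwrap(members[0])
    for mi in members[1:]:
        dst = unwrap(mi)
        for state in src._defaults:
            value = getattr(src, state)
            # Either pass the state by reference or give dst its own copy
            setattr(dst, state, deepcopy(value) if copy else value)
        dst._update_count = src._update_count


acc = ToyWrapper(ToyMetric())
f1 = ToyWrapper(ToyMetric())
acc.metric.update(5)  # only the first group member is updated
link_group([acc, f1])
print(f1.metric.tp, f1.metric._update_count)  # 5 1
```

Copying the update count as well keeps the wrapped metric from treating itself as never updated when it is computed.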

Solution 2

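from typing import Any


# Patched excerpt of ClasswiseWrapper: only the overridden methods are shown
class ClasswiseWrapper:
    def __getattr__(self, name: str):
        # Return the state from self.metric; tp/fp/fn/tn are the states of the
        # stat-scores based metrics used in the example (accuracy, F1 score)
        if name in ["tp", "fp", "fn", "tn"]:   # <<Added
            return getattr(self.metric, name)  # <<Added

        return super().__getattr__(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Once the wrapped metric exists, write state values through to it
        if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn"]:  # <<Added
            setattr(self.metric, name, value)                             # <<Added
        else:                                                             # <<Added
            super().__setattr__(name, value)
            if name == "metric":                                          # <<Added
                # Expose the wrapped metric's state names to MetricCollection
                self._defaults = self.metric._defaults                    # <<Added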
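
Solution 2 instead makes the wrapper transparent for state access, so that generic state-sharing code never needs to know about the wrapper. The pattern can be sketched in plain Python. ForwardingWrapper and ToyMetric are hypothetical names. Unlike Solution 2, this sketch checks the wrapped metric's _defaults rather than a fixed list of state names, so it is not tied to the stat-scores states.

```python
class ToyMetric:
    """Hypothetical metric holding its states as plain attributes."""

    def __init__(self):
        self._defaults = {"tp": 0, "fp": 0}
        self.tp = 0
        self.fp = 0


class ForwardingWrapper:
    """Hypothetical wrapper that forwards state reads/writes to self.metric."""

    def __init__(self, metric):
        self.metric = metric
        # Mirror the wrapped metric's state names so that code iterating
        # over wrapper._defaults sees the real states
        self._defaults = metric._defaults

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for names not set on the wrapper
        metric = self.__dict__.get("metric")
        if metric is not None and name in metric._defaults:
            return getattr(metric, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        metric = self.__dict__.get("metric")
        if metric is not None and name in metric._defaults:
            setattr(metric, name, value)  # write through to the wrapped metric
        else:
            object.__setattr__(self, name, value)


w0 = ForwardingWrapper(ToyMetric())
w1 = ForwardingWrapper(ToyMetric())
w0.metric.tp = 7

# Generic state sharing that only sees the wrappers
for state in w0._defaults:
    setattr(w1, state, getattr(w0, state))

print(w1.metric.tp)      # 7
print("tp" in vars(w1))  # False: the value landed on the wrapped metric
```

Because __getattr__ only runs when normal lookup fails, attributes stored on the wrapper itself (metric, _defaults) are unaffected. Only the state names are forwarded.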


  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 1.3.1
  • Python & PyTorch Version (e.g., 1.0):
    • Python: 3.8
    • Pytorch: 2.1.1
  • Any other relevant information such as OS (e.g., Linux): Windows 11

Additional context

Updated Solution 2: it now overrides __getattr__ and __setattr__ of ClasswiseWrapper so that state reads and writes are forwarded to the wrapped metric.