Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.

Home Page: https://lightning.ai/docs/torchmetrics/

`log_dict` method breaks on `MultitaskWrapper` + `MetricCollection` combination

ponderbb opened this issue · comments

πŸ› Bug

I am trying to log metrics with the MultitaskWrapper and MetricCollection combination, analogous to the test case found in integrations/test_lightning.

Logging this configuration of metrics with Lightning's built-in log_dict method throws a ValueError (see the reproduction snippet below). This is confusing, since logging either a MultitaskWrapper or a MetricCollection individually works with log_dict, and this exact logging approach is presented in the integration tests.

To Reproduce

Steps to reproduce the behavior...

Code sample

The multitask configuration from integrations/test_lightning::test_task_wrapper_lightning_logging:

self.multitask_collection = MultitaskWrapper({
    "classification": MetricCollection([BinaryAccuracy(), BinaryAveragePrecision()]),
    "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
})

The log_dict method splits the MultitaskWrapper into its individual tasks; however, it then tries to call the log method on each MetricCollection, which breaks with the following ValueError:

E       ValueError: `self.log(classification, MetricCollection(
E         (BinaryAccuracy): BinaryAccuracy()
E         (BinaryAveragePrecision): BinaryAveragePrecision()
E       ))` was called, but `MetricCollection` values cannot be logged
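The failure mode can be modelled with plain dicts. Below is a simplified sketch (my assumption about the dispatch logic, not Lightning's actual implementation): `log_dict` calls `log(key, value)` once per entry, and `log` rejects any value it cannot reduce to a scalar, which is exactly what the nested result of a MultitaskWrapper looks like.

```python
# Simplified model (an assumption, not Lightning's real code) of how
# `log_dict` dispatches to `log` and rejects non-scalar values.
class MiniLogger:
    def log(self, key, value):
        if isinstance(value, dict):  # stands in for a MetricCollection value
            raise ValueError(
                f"`self.log({key}, {value})` was called, "
                "but these values cannot be logged"
            )
        return key, value

    def log_dict(self, metrics):
        # One `log` call per entry; a nested dict reaches `log` unflattened.
        return [self.log(k, v) for k, v in metrics.items()]


# A MultitaskWrapper-style result is a dict of dicts, so the inner dict
# triggers the error, mirroring the traceback above.
nested = {"classification": {"BinaryAccuracy": 0.91}}
try:
    MiniLogger().log_dict(nested)
except ValueError as err:
    print(err)
```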

This ValueError is not caught when running the test, as it seems to be silenced by the no_warning_call helper.

Expected behavior

log_dict should work for the MultitaskWrapper + MetricCollection combination, or there should be a warning that this is not possible; as it stands, the corresponding test case is misleading.
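Until the behavior is fixed, one possible stop-gap (my own sketch, not a torchmetrics API) is to flatten the nested dict that MultitaskWrapper.compute() returns into a flat dict of scalars before passing it to log_dict. Plain floats stand in here for the tensors real metrics would return:

```python
def flatten_multitask(results: dict) -> dict:
    """Flatten {task: {metric_name: value}} into {"task_metric_name": value}."""
    flat = {}
    for task, metrics in results.items():
        if isinstance(metrics, dict):  # MetricCollection-style nested result
            for name, value in metrics.items():
                flat[f"{task}_{name}"] = value
        else:  # a plain single-metric task
            flat[task] = metrics
    return flat


nested = {
    "classification": {"BinaryAccuracy": 0.91, "BinaryAveragePrecision": 0.88},
    "regression": {"MeanSquaredError": 0.02, "MeanAbsoluteError": 0.11},
}
print(flatten_multitask(nested))
```

Inside a LightningModule this would amount to something like `self.log_dict(flatten_multitask(self.multitask_collection.compute()))`, under the assumption that compute() returns the nested per-task dict shown above.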

Environment

  • CentOS
  • Python 3.11
  • torchmetrics == 1.3.0.post0 (with poetry)
  • torch == 2.2.0.post0 (with poetry)

Hi @ponderbb, thanks for raising this issue and sorry for my slow reply.
It would seem that this was actually fixed in PR #2349, which is part of the v1.3.1 release of torchmetrics. It should therefore work if you update to the newest torchmetrics version.
Closing issue.