`log_dict` method breaks on `MultitaskWrapper` + `MetricCollection` combination
ponderbb opened this issue
🐛 Bug
I am trying to log metrics with the `MultitaskWrapper` and `MetricCollection` combination, which is analogous to the test case that can be found in `integrations/test_lightning`.
Logging this configuration of metrics with Lightning's built-in `log_dict` method throws a `ValueError` (see the reproduction snippet). This is confusing, because logging either `MultitaskWrapper` or `MetricCollection` individually works with `log_dict`, and this exact logging approach is presented in the integration tests.
To Reproduce
Steps to reproduce the behavior...
Code sample
The multitask configuration from `integrations/test_lightning::test_task_wrapper_lightning_logging`:
self.multitask_collection = MultitaskWrapper({
    "classification": MetricCollection([BinaryAccuracy(), BinaryAveragePrecision()]),
    "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
})
The `log_dict` method splits the `MultitaskWrapper` into the individual tasks, but it then tries to call the `log` method on each task's `MetricCollection`, which fails with the following `ValueError`:
E ValueError: `self.log(classification, MetricCollection(
E (BinaryAccuracy): BinaryAccuracy()
E (BinaryAveragePrecision): BinaryAveragePrecision()
E ))` was called, but `MetricCollection` values cannot be logged
This `ValueError` is not caught when running the test, as it seems to be silenced by the `no_warning_call`.
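The failure mode described above can be sketched without Lightning or torchmetrics installed. The names below (`FakeCollection`, `log`, `log_dict`) are hypothetical stand-ins, not the real implementations. They assume only what this report states: `log_dict` forwards each per-task entry to `log`, and `log` rejects collection values.

```python
class FakeCollection(dict):
    """Hypothetical stand-in for a MetricCollection: a named group of metrics."""


def log(name, value, store):
    # Assumption: like the real `log`, reject whole collections of metrics.
    if isinstance(value, FakeCollection):
        raise ValueError(
            f"`self.log({name}, ...)` was called, but "
            f"`{type(value).__name__}` values cannot be logged"
        )
    store[name] = value


def log_dict(metrics, store):
    # Assumption: one `log` call per top-level entry, with no recursion.
    for name, value in metrics.items():
        log(name, value, store)


# Flat per-task values are logged without problems.
flat_store = {}
log_dict({"classification": 0.9, "regression": 0.1}, flat_store)

# Per-task collections fail on the first entry.
nested = {
    "classification": FakeCollection(accuracy=0.9, average_precision=0.8),
    "regression": FakeCollection(mse=0.1, mae=0.2),
}
try:
    log_dict(nested, {})
    error_message = None
except ValueError as err:
    error_message = str(err)
```

In this sketch the nested case raises because nothing between `log_dict` and `log` unpacks the inner collection, which matches the shape of the error message shown above.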
Expected behavior
`log_dict` should work for the `MultitaskWrapper` + `MetricCollection` combination, or there should be a warning that this is not possible. As it stands, the corresponding test case is misleading.
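One possible workaround, not verified against the versions listed in this report, is to flatten the per-task results into a single-level dictionary before logging. `flatten_task_metrics` below is a hypothetical helper, not part of torchmetrics or Lightning. It assumes the wrapper's computed output is a nested mapping of the form `{task: {metric_name: value}}`, and it uses plain floats in place of tensors so that the sketch runs on its own.

```python
from collections.abc import Mapping


def flatten_task_metrics(nested, sep="_"):
    """Flatten {task: {metric: value}} into {"task<sep>metric": value}."""
    flat = {}
    for task, value in nested.items():
        if isinstance(value, Mapping):
            for metric_name, metric_value in value.items():
                flat[f"{task}{sep}{metric_name}"] = metric_value
        else:
            # The task already holds a single value; keep its key unchanged.
            flat[task] = value
    return flat


# Illustrative values only; real results would be tensors.
computed = {
    "classification": {"BinaryAccuracy": 0.9, "BinaryAveragePrecision": 0.8},
    "regression": {"MeanSquaredError": 0.1, "MeanAbsoluteError": 0.2},
}
flat = flatten_task_metrics(computed)
```

Inside a `LightningModule`, the flattened dictionary could then be passed to `self.log_dict(flat)`. This would log computed values rather than metric objects, so resetting the metrics would presumably have to be handled by hand.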
Environment
- CentOS
- Python 3.11
- torchmetrics == 1.3.0.post0 (with poetry)
- torch == 2.2.0.post0 (with poetry)