MetricWrapper for Target Binarization
lgienapp opened this issue Β· comments
🚀 Feature
Add a TargetBinarizationWrapper that casts continuous labels to binary labels given a threshold.
Motivation
Evaluating two metrics that require different label formats (e.g., one binary, the other continuous) is cumbersome: it requires setting up two separate evaluation stacks, one fed with binarized label data and the other with the original continuous data. This leads to code duplication. Also, persisting binarized labels into the dataset, in scenarios where a metric requires different input than what is given in the ground-truth data, diminishes the clarity of the evaluation code.
Pitch
A metric wrapper that casts target data to binary targets during the .update() and .forward() methods. It can be applied to either a single Metric or a whole MetricCollection.
Alternatives
- Using a MultiTaskWrapper is possible, but has two caveats: (1) metrics with a different signature than update(pred, target) are not supported, and (2) it requires the user to implement the thresholding logic themselves before feeding data into the MultiTaskWrapper.
- A more generic target-processing wrapper could allow supplying, e.g., a custom lambda that is applied to targets; this is more flexible, but also requires the user to implement their own logic. I think binarization is a common enough problem in torchmetrics (since its metrics make a binary vs. non-binary distinction) to warrant its own wrapper.
Additional Information
Consider the following example of the desired behaviour:
```python
import torch
from torchmetrics.wrappers import BinarizedTargetWrapper  # <-- This does not exist
from torchmetrics.collections import MetricCollection
from torchmetrics.retrieval import RetrievalNormalizedDCG, RetrievalMRR

preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
targets = torch.tensor([1, 0, 0, 0, 0, 2, 1, 0, 0, 0])
topics = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

metrics = MetricCollection({
    "RetrievalNormalizedDCG": RetrievalNormalizedDCG(),
    # <-- Enable this kind of metric composition, which is not possible currently
    "RetrievalMRR": BinarizedTargetWrapper(RetrievalMRR(), threshold=0),
})
metrics.update(preds, targets, indexes=topics)
metrics.compute()
```
If simple binarization as in the example is a desired solution, I have all the code needed for a pull request ready and can take on this issue.
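For concreteness, the thresholding step the wrapper would perform is a one-liner in torch. The helper name binarize_targets below is hypothetical, used only to illustrate the intended semantics (labels strictly above the threshold become 1, everything else 0):

```python
import torch

def binarize_targets(targets: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Hypothetical helper mirroring what BinarizedTargetWrapper would do
    # internally: values strictly above the threshold map to 1, the rest to 0.
    return (targets > threshold).long()

targets = torch.tensor([1, 0, 0, 0, 0, 2, 1, 0, 0, 0])
print(binarize_targets(targets, threshold=0).tolist())  # → [1, 0, 0, 0, 0, 1, 1, 0, 0, 0]
```

Note that the graded label 2 and the binary label 1 both collapse to 1, which is exactly what RetrievalMRR expects while RetrievalNormalizedDCG keeps the original graded targets.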
Hi! Thanks for your contribution, great first issue!
Hi @lgienapp, thanks for opening this issue.
I am not against adding this feature, but maybe we should consider adding:
- a general MetricInputTransformer class where the user can provide custom functions for transforming the input; BinarizedTargetWrapper can then be included as just a subclass of MetricInputTransformer with pre-selected transforms.

How does that sound?
Establishing a general class sounds good. Just for clarification: would the general MetricInputTransformer still subclass WrapperMetric to inherit all the "reset-sync" code from there (e.g., live in torchmetrics.wrappers.transformations)? Or would it be its own thing (e.g., live in torchmetrics.transformations.abstract), possibly duplicating the sync code from the wrapper base class?
In either case, I would propose an implementation like this (subclassing wrappers here), assuming that only positional parameters like preds and targets would be interesting to modify (and thus ignoring kwargs such as indexes):
```python
from typing import Any, Tuple, Union

import torch
from torchmetrics import Metric, MetricCollection
from torchmetrics.wrappers.abstract import WrapperMetric


class MetricInputTransformer(WrapperMetric):
    def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: Any):
        super().__init__(**kwargs)
        self.wrapped_metric = wrapped_metric

    def transform(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        raise NotImplementedError

    def update(self, *args: Any, **kwargs: Any) -> None:
        self.wrapped_metric.update(*self.transform(*args), **kwargs)

    def compute(self) -> Any:
        return self.wrapped_metric.compute()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # "return" was missing in the original snippet, silently dropping the result
        return self.wrapped_metric.forward(*self.transform(*args), **kwargs)
```
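To sketch how BinarizedTargetWrapper would fall out of such a base class, here is a self-contained toy version. The stand-in base class and the CountPositives metric are illustrative only (the real implementation would build on torchmetrics' WrapperMetric and real metrics); they exist so the composition can be demonstrated without further dependencies:

```python
import torch


class MetricInputTransformer:
    # Stand-in for the proposed base class; the real one would subclass
    # torchmetrics' WrapperMetric to inherit the reset/sync machinery.
    def __init__(self, wrapped_metric):
        self.wrapped_metric = wrapped_metric

    def transform(self, *args):
        raise NotImplementedError

    def update(self, *args, **kwargs):
        self.wrapped_metric.update(*self.transform(*args), **kwargs)

    def compute(self):
        return self.wrapped_metric.compute()


class BinarizedTargetWrapper(MetricInputTransformer):
    def __init__(self, wrapped_metric, threshold: float = 0.0):
        super().__init__(wrapped_metric)
        self.threshold = threshold

    def transform(self, preds, target):
        # Predictions pass through unchanged; targets above the threshold become 1.
        return preds, (target > self.threshold).long()


class CountPositives:
    # Toy metric used only to show the wrapper in action.
    def __init__(self):
        self.positives = 0

    def update(self, preds, target):
        self.positives += int(target.sum())

    def compute(self):
        return self.positives


metric = BinarizedTargetWrapper(CountPositives(), threshold=0)
metric.update(torch.tensor([0.9, 0.8, 0.7]), torch.tensor([1, 0, 2]))
print(metric.compute())  # → 2 (graded labels 1 and 2 each count as one positive)
```

The pre-selected transform lives entirely in BinarizedTargetWrapper.transform, so other input transformations (clipping, re-scaling, label remapping) would be equally small subclasses.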