Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.

Home Page:https://lightning.ai/docs/torchmetrics/

MetricWrapper for Target Binarization

lgienapp opened this issue · comments

🚀 Feature

Add a TargetBinarizationWrapper that casts continuous labels to binary labels given a threshold.

Motivation

Evaluating two metrics that require different label formats (e.g., one binary, the other continuous) is cumbersome: it requires setting up two separate evaluation stacks, one fed binarized label data and the other the original continuous data, which leads to code duplication. Alternatively, persisting binarized labels into the dataset, in scenarios where a metric requires different input than what the ground-truth data provides, obscures what the evaluation process actually does.

Pitch

A metric wrapper that casts target data to binary targets during the .update() and .forward() methods. It can be applied to either a single Metric or a whole MetricCollection.

Alternatives

  • using a MultiTaskWrapper is possible, but has two caveats: (1) metrics with a signature other than update(pred, target) are not supported, and (2) the user has to implement the thresholding logic themselves before feeding data into the MultiTaskWrapper
  • a more generic target-processing wrapper that accepts, e.g., a custom lambda applied to the targets would be more flexible, but likewise requires users to implement their own logic. I think binarization is a common enough problem in torchmetrics (since its metrics make a binary vs. non-binary distinction) to warrant its own wrapper.
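For context, the thresholding step that the first alternative forces users to write by hand is small but repetitive. A minimal sketch of it (the helper name `binarize_targets` is illustrative, not part of torchmetrics):

```python
import torch

# Illustrative helper (not a torchmetrics API): the manual thresholding step
# users currently apply to targets before feeding them to a MultiTaskWrapper.
def binarize_targets(targets: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Labels strictly above the threshold become 1, all others become 0.
    return (targets > threshold).long()

graded = torch.tensor([1, 0, 0, 0, 0, 2, 1, 0, 0, 0])
binarize_targets(graded, threshold=0)  # tensor([1, 0, 0, 0, 0, 1, 1, 0, 0, 0])
```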

Additional Information

Consider the following example of the desired behaviour:

import torch
from torchmetrics.wrappers import BinarizedTargetWrapper # <-- This does not exist
from torchmetrics.collections import MetricCollection
from torchmetrics.retrieval import RetrievalNormalizedDCG, RetrievalMRR 

preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
targets = torch.tensor([1, 0, 0, 0, 0, 2, 1, 0, 0, 0])
topics = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

metrics = MetricCollection({
    "RetrievalNormalizedDCG": RetrievalNormalizedDCG(),
    "RetrievalMRR": BinarizedTargetWrapper(RetrievalMRR(), threshold=0), # <-- Enable this kind of metric composition, which is not possible currently
})

metrics.update(preds, targets, indexes=topics)
metrics.compute()

If simple binarization as in the example is a desired solution, I have all the code needed for a pull request ready and can take on this issue.

Hi! Thanks for your contribution, great first issue!

Hi @lgienapp, thanks for opening this issue.
I am not against adding this feature, but maybe we should consider adding:

  • a general MetricInputTransformer class where users can provide custom functions for transforming the input
  • then BinarizedTargetWrapper can be included as just a subclass of MetricInputTransformer with pre-selected transforms

how does that sound?

Establishing a general class sounds good. Just for clarification: the general MetricInputTransformer would still subclass the WrapperMetric to inherit all the "reset-sync" code from there (e.g., live in torchmetrics.wrappers.transformations)? Or be its own thing (e.g., live in torchmetrics.transformations.abstract), possibly duplicating the sync code from the wrapper base class?

In either case, I would propose an implementation like the following (subclassing WrapperMetric here), assuming that only positional arguments like preds and targets are interesting to modify (keyword arguments such as indexes are passed through unchanged):

from typing import Any, Tuple, Union

import torch
from torchmetrics import Metric, MetricCollection
from torchmetrics.wrappers.abstract import WrapperMetric


class MetricInputTransformer(WrapperMetric):
    """Wraps a metric (or collection) and applies ``transform`` to the positional inputs."""

    def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.wrapped_metric = wrapped_metric

    def transform(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Transform the positional inputs; to be implemented by subclasses."""
        raise NotImplementedError

    def update(self, *args: torch.Tensor, **kwargs: Any) -> None:
        self.wrapped_metric.update(*self.transform(*args), **kwargs)

    def compute(self) -> Any:
        return self.wrapped_metric.compute()

    def forward(self, *args: torch.Tensor, **kwargs: Any) -> Any:
        return self.wrapped_metric.forward(*self.transform(*args), **kwargs)
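Under that design, the binarization wrapper reduces to implementing the transform hook. A hedged sketch of what that subclass might look like (shown standalone here for illustration; in an actual PR, BinarizedTargetWrapper would subclass the MetricInputTransformer above and forward the wrapped metric through its constructor):

```python
from typing import Tuple

import torch

class BinarizedTargetWrapper:  # would subclass MetricInputTransformer in an actual PR
    """Sketch: binarizes targets before delegating to the wrapped metric."""

    def __init__(self, threshold: float = 0.0) -> None:
        self.threshold = threshold

    def transform(self, preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Predictions pass through unchanged; targets above the threshold become 1, else 0.
        return preds, (target > self.threshold).long()

wrapper = BinarizedTargetWrapper(threshold=0)
preds = torch.tensor([0.9, 0.8, 0.7])
target = torch.tensor([1, 0, 2])
wrapper.transform(preds, target)  # preds unchanged, targets -> tensor([1, 0, 1])
```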