Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains

Home Page:https://lightning-flash.readthedocs.io

Support ClasswiseWrapper Metrics for Classification Tasks

newzealandpaul opened this issue

🚀 Feature

Currently, the torchmetrics ClasswiseWrapper, which computes a metric separately for each class, is not supported by Lightning Flash.

Motivation

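Per-class metrics are essential for many classification tasks because they show how the model performs on each individual class, which a single aggregate score can hide (for example, poor recall on a rare class).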

Pitch

Currently, passing a ClasswiseWrapper() metric when creating a new instance of a Flash model raises an error in flash/core/model.py:373, because ClasswiseWrapper objects do not have a _forward_cache attribute. Working around that leads to a second error in trainer/connectors/logger_connector/result.py:548, which expects a tensor rather than a dict of tensors.
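To illustrate the second error: a ClasswiseWrapper produces a dict of per-class values rather than a single tensor, so one possible workaround is to expand that dict into separate scalar entries before logging. The sketch below uses a hypothetical helper, flatten_metric_output, which is not part of Flash, Lightning, or torchmetrics; the "name/class" key format is likewise an illustrative choice, not the library's behavior.

```python
from typing import Dict, Mapping, Union

Number = Union[int, float]


def flatten_metric_output(
    name: str, value: Union[Number, Mapping[str, Number]]
) -> Dict[str, float]:
    """Hypothetical helper: turn a metric result into {log_key: scalar}.

    A plain metric yields one scalar, logged under `name`.
    A ClasswiseWrapper-style result is a dict of per-class values,
    so each entry gets its own key, e.g. "val_acc/cat".
    """
    if isinstance(value, Mapping):
        # float() also works on 0-d torch tensors, so tensor values are accepted too.
        return {f"{name}/{key}": float(v) for key, v in value.items()}
    return {name: float(value)}


if __name__ == "__main__":
    per_class = {"cat": 0.9, "dog": 0.75}  # made-up example values
    print(flatten_metric_output("val_acc", per_class))
    # {'val_acc/cat': 0.9, 'val_acc/dog': 0.75}
    print(flatten_metric_output("val_f1", 0.8))
    # {'val_f1': 0.8}
```

Inside a LightningModule, the resulting mapping could then be passed to self.log_dict(...), which accepts scalar values. The actual keys a ClasswiseWrapper returns depend on how it is configured (for example, whether labels are supplied).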

Users would expect torchmetrics features to be supported natively.