enochkan / torch-metrics

Metrics for model evaluation in PyTorch



License: MIT

Model evaluation metrics for PyTorch

Torch-metrics is a custom library that provides common ML evaluation metrics in PyTorch, similar to tf.keras.metrics.

As summarized in this issue, PyTorch does not have a built-in torch.metrics library for model evaluation metrics. Torch-metrics fills that gap and is similar to the metrics library in PyTorch Lightning.


  • pip install --upgrade torch-metrics

import torch
from torch_metrics import Accuracy

# Predictions given as class indices: pass from_logits=False
metric = Accuracy(from_logits=False)
y_pred = torch.tensor([1, 2, 3, 4])
y_true = torch.tensor([0, 2, 3, 4])

print(metric(y_pred, y_true))

# Predictions given as per-class scores (one row per sample): use the default
metric = Accuracy()
y_pred = torch.tensor([[0.2, 0.6, 0.1, 0.05, 0.05],
                       [0.2, 0.1, 0.6, 0.05, 0.05],
                       [0.2, 0.05, 0.1, 0.6, 0.05],
                       [0.2, 0.05, 0.05, 0.05, 0.65]])
y_true = torch.tensor([0, 2, 3, 4])

print(metric(y_pred, y_true))
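
For readers who want to see what an accuracy metric computes on these two input forms, here is a minimal sketch in plain PyTorch. It is not the library's implementation: `accuracy_sketch` is a hypothetical helper, and it assumes that 2-D predictions hold one row of per-class scores per sample, reduced to a class index with `argmax`.

```python
import torch


def accuracy_sketch(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Fraction of predictions equal to the labels (hypothetical helper)."""
    # Assumption: a 2-D input holds one row of per-class scores per sample,
    # so the predicted class is the index of the largest score in each row.
    if y_pred.dim() == 2:
        y_pred = y_pred.argmax(dim=1)
    # Compare element-wise and average the matches.
    return (y_pred == y_true).float().mean().item()


labels = torch.tensor([0, 2, 3, 4])
class_ids = torch.tensor([1, 2, 3, 4])
scores = torch.tensor([[0.2, 0.6, 0.1, 0.05, 0.05],
                       [0.2, 0.1, 0.6, 0.05, 0.05],
                       [0.2, 0.05, 0.1, 0.6, 0.05],
                       [0.2, 0.05, 0.05, 0.05, 0.65]])

print(accuracy_sketch(class_ids, labels))  # 0.75 (3 of 4 indices match)
print(accuracy_sketch(scores, labels))     # 0.75 (row-wise argmax is 1, 2, 3, 4)
```

Both calls use the same data as the usage example above, so three of the four predictions agree with the labels in each case.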


Metrics from tf.keras.metrics and other metrics, either already implemented or on the to-do list:

  • MeanSquaredError class
  • RootMeanSquaredError class
  • MeanAbsoluteError class
  • Precision class
  • Recall class
  • MeanIoU class
  • DSC class (Dice Similarity Coefficient)
  • F1Score class
  • RSquared class
  • Hinge class
  • SquaredHinge class
  • LogCoshError class
  • Accuracy class
  • KLDivergence class
  • CosineSimilarity class
  • AUC class
  • BinaryCrossEntropy class
  • CategoricalCrossEntropy class
  • SparseCategoricalCrossentropy class
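
As a reminder of what several of the listed metrics measure, the following sketch writes out their textbook definitions in plain PyTorch. The function names are hypothetical and are not the classes listed above; the library's own implementations may differ in reduction, input handling, and edge cases.

```python
import torch


def mse_sketch(y_pred, y_true):
    # Mean squared error: mean of squared differences.
    return torch.mean((y_pred - y_true) ** 2)


def rmse_sketch(y_pred, y_true):
    # Root mean squared error: square root of the MSE.
    return torch.sqrt(mse_sketch(y_pred, y_true))


def mae_sketch(y_pred, y_true):
    # Mean absolute error: mean of absolute differences.
    return torch.mean(torch.abs(y_pred - y_true))


def r_squared_sketch(y_pred, y_true):
    # R squared: 1 - (residual sum of squares / total sum of squares).
    ss_res = torch.sum((y_true - y_pred) ** 2)
    ss_tot = torch.sum((y_true - y_true.mean()) ** 2)
    return 1 - ss_res / ss_tot


def precision_recall_f1_sketch(y_pred, y_true):
    # Binary labels in {0, 1}. No guard against zero denominators,
    # so inputs with no positives produce NaN.
    tp = ((y_pred == 1) & (y_true == 1)).sum().float()
    fp = ((y_pred == 1) & (y_true == 0)).sum().float()
    fn = ((y_pred == 0) & (y_true == 1)).sum().float()
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


reg_true = torch.tensor([1.0, 2.0, 3.0, 4.0])
reg_pred = torch.tensor([1.0, 2.0, 3.0, 6.0])
print(mse_sketch(reg_pred, reg_true).item())        # 1.0
print(rmse_sketch(reg_pred, reg_true).item())       # 1.0
print(mae_sketch(reg_pred, reg_true).item())        # 0.5
print(r_squared_sketch(reg_pred, reg_true).item())  # approximately 0.2

bin_true = torch.tensor([1, 0, 1, 0])
bin_pred = torch.tensor([1, 1, 0, 0])
p, r, f1 = precision_recall_f1_sketch(bin_pred, bin_true)
print(p.item(), r.item(), f1.item())  # 0.5 0.5 0.5
```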

Local Development and Testing

To quickly get started with local development, run:

make develop

To test, run:

python3 -m pytest

Pre-commit hooks

To run pre-commit against all files:

pre-commit run --all-files


Please raise issues or feature requests here. It will be extremely helpful if you comment on a specific issue before working on it. This provides visibility for others who also intend to work on the same issue. Reference any pull requests to their original issues.



