Lightning-AI / torchmetrics

TorchMetrics - Machine learning metrics for distributed, scalable PyTorch applications.

Home Page: https://lightning.ai/docs/torchmetrics/


High memory usage of Perplexity metric

nsmlzl opened this issue

πŸ› Bug

I ran out of GPU memory when computing the Perplexity metric and would like to propose a small optimization that decreases its memory usage.

To Reproduce

For instance, running the following code makes PyTorch try to allocate 1024 GB of GPU memory on my system.

from torchmetrics.text import Perplexity
import torch

gen = torch.manual_seed(42)
preds = torch.rand(512, 1024, 12, generator=gen).cuda()
target = torch.randint(12, (512, 1024), generator=gen).cuda()

perp = Perplexity().cuda()
print(perp(preds, target))

Memory Inefficiency

I think the inefficiency is in this line:

probs = probs[:, target].diagonal()[mask]

probs[:, target] materializes a large temporary tensor with (512 * 1024)^2 = 2^38 elements, of which only the diagonal is used afterwards. At 4 bytes per float32 element, that is 2^40 bytes = 1 TiB, which matches the 1024 GB allocation above.
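For a quick shape check, here is a minimal sketch with toy sizes (N and vocab_size are placeholder names standing in for the flattened probs of shape (N, vocab_size) and target of shape (N,) inside the metric's update step):

import torch

N, vocab_size = 4, 3  # toy stand-ins for 512 * 1024 positions and 12 classes
probs = torch.softmax(torch.rand(N, vocab_size), dim=1)
target = torch.randint(vocab_size, (N,))

# Indexing dim 1 with an N-element index tensor materializes an (N, N)
# intermediate, quadratic in the number of token positions, of which
# only the diagonal is kept.
expanded = probs[:, target]
print(expanded.shape)             # torch.Size([4, 4])
print(expanded.diagonal().shape)  # torch.Size([4])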

Potential Solution

In contrast,

probs = probs[torch.arange(target.numel()), target][mask]

would only require memory proportional to the size of target.
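As a quick sanity check (a sketch on CPU with the same toy sizes as above), both indexing variants read exactly the same elements, so they agree bit for bit, while the proposed one never builds the quadratic intermediate:

import torch

N, vocab_size = 4, 3
probs = torch.softmax(torch.rand(N, vocab_size), dim=1)
target = torch.randint(vocab_size, (N,))
mask = torch.ones_like(target, dtype=torch.bool)  # no ignore_index in this sketch

# Current: builds an (N, N) intermediate, then keeps only the diagonal.
current = probs[:, target].diagonal()[mask]

# Proposed: picks one probability per row, allocating only O(N) memory.
proposed = probs[torch.arange(target.numel()), target][mask]

assert torch.equal(current, proposed)

An equivalent O(N) formulation would be probs.gather(1, target.unsqueeze(1)).squeeze(1)[mask], in case gather is preferred over advanced indexing.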

Would you consider accepting a pull request with this optimization? Or was the previous implementation chosen for another reason?

Environment

  • TorchMetrics v1.2.1 (installed with pip) and the master branch
  • Python 3.10.12
  • PyTorch 2.2.0
  • CUDA 12.1

Hi! Thanks for your contribution, great first issue!

Just created PR #2346 with the (small) change. Feel free to merge it when you like.