High memory usage of Perplexity metric
nsmlzl opened this issue
🐛 Bug
I ran out of GPU memory when computing the perplexity metric and would like to propose a small optimization that decreases its memory utilization.
To Reproduce
For instance, running the following code makes PyTorch try to allocate 1024 GB of GPU memory on my system.
from torchmetrics.text import Perplexity
import torch

gen = torch.manual_seed(42)
# 512 sequences of length 1024 over a vocabulary of 12 tokens
preds = torch.rand(512, 1024, 12, generator=gen).cuda()
target = torch.randint(12, (512, 1024), generator=gen).cuda()

perp = Perplexity().cuda()
print(perp(preds, target))
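The size of the failed allocation is consistent with a single float32 tensor of shape (512*1024, 512*1024), as a quick back-of-the-envelope check shows:

# Size in GiB of a float32 tensor with (512*1024)^2 elements
n = 512 * 1024            # flattened batch * sequence length
print(n * n * 4 / 2**30)  # -> 1024.0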
Memory Inefficiency
I think the inefficiency is in this line:
probs[:, target]
This materializes a large temporary tensor with (512*1024)^2 elements, of which only the diagonal values are used afterwards.
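To illustrate with shapes small enough to run anywhere (the names probs and target mirror the snippet above; the toy sizes are made up for this sketch):

import torch

# Tiny stand-in: 6 flattened tokens, vocabulary of 4
probs = torch.rand(6, 4).softmax(dim=1)
target = torch.randint(4, (6,))

# Advanced indexing with a length-6 index selects 6 columns for *each*
# of the 6 rows, so the temporary grows quadratically in the token count
temp = probs[:, target]
print(temp.shape)             # torch.Size([6, 6])
print(temp.diagonal().shape)  # torch.Size([6]) -- all that is kept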
Potential Solution
In contrast,

probs = probs[torch.arange(target.numel()), target][mask]

would only require memory proportional to the size of target.
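A quick sanity check, reusing the toy shapes from above, that both index patterns select exactly the same per-token probabilities (mask is assumed to be the boolean validity mask over the flattened targets, all-true here for simplicity):

import torch

probs = torch.rand(6, 4).softmax(dim=1)
target = torch.randint(4, (6,))
mask = torch.ones_like(target, dtype=torch.bool)

old = probs[:, target].diagonal()[mask]                  # quadratic temporary
new = probs[torch.arange(target.numel()), target][mask]  # linear memory
print(torch.equal(old, new))  # True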
Would you consider accepting a pull request with this optimization? Or was the current implementation chosen for another reason?
Environment
- TorchMetrics v1.2.1 (installed with pip) and master branch
- Python 3.10.12
- PyTorch 2.2.0
- CUDA 12.1
Hi! Thanks for your contribution, great first issue!