audiolabs / torch-pesq

PyTorch implementation of the Perceptual Evaluation of Speech Quality for wideband audio

NameError: name 'batch' is not defined.

MorenoLaQuatra opened this issue

Hi,

First of all, thank you for the contribution and for making the code available.
I had a problem while computing the loss. I tried:

import torch
from torch_pesq import PesqLoss
pesq = PesqLoss(0.5, sample_rate=44100)

t1 = torch.rand(88200)
t2 = torch.rand(88200)

loss = pesq(t1, t2)

I got the following error:

File ~/miniconda3/envs/audio/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/audio/lib/python3.10/site-packages/torch_pesq/loss.py:301, in forward(self, ref, deg)
    295     mos = 0.999 + 4 / (1 + torch.exp(-1.3669 * mos + 3.8224))
    297     return mos
    299 @typechecked
    300 def forward(
--> 301     self, ref: TensorType["batch", "sample"], deg: TensorType["batch", "sample"]
    302 ) -> TensorType["batch", "sample"]:
    303     """Calculate a loss variant of the MOS score
    304 
    305     This function combines symmetric and asymmetric distances but does not apply a range
   (...)
    318         Loss value in range [0, inf)
    319     """
    320     d_symm, d_asymm = self.raw(ref, deg)

NameError: name 'batch' is not defined

Most probably I'm not using it correctly; do you have any suggestions?

Not sure what's going wrong there; perhaps @nils-werner knows more. Can you also provide the version of the typeguard package?

Sure, here are the packages that may be relevant:

torch==2.0.1
torch-pesq==0.1.2
torchaudio==2.0.2
torchtyping==0.1.4
typeguard==4.0.0

Fixed by pinning typeguard to version 2.* in #6.
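The pin corresponds to the pip requirement `typeguard==2.*`, and the environment reported above (typeguard==4.0.0) falls outside it. As a minimal sketch, assuming you want to check an existing environment against that pin before importing torch_pesq, the helpers below read the installed typeguard version with the standard library's `importlib.metadata` and report whether its major version is 2. The names `major_version` and `typeguard_is_compatible` are hypothetical and not part of torch-pesq.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def major_version(version_string: str) -> int:
    """Return the leading integer of a version string such as "4.0.0"."""
    match = re.match(r"\d+", version_string)
    if match is None:
        raise ValueError(f"unrecognised version string: {version_string!r}")
    return int(match.group())


def typeguard_is_compatible(installed: Optional[str] = None) -> bool:
    """Hypothetical helper: check that typeguard satisfies the 2.* pin.

    If ``installed`` is None, the version of the installed typeguard
    distribution is looked up; a missing typeguard counts as incompatible.
    """
    if installed is None:
        try:
            installed = version("typeguard")
        except PackageNotFoundError:
            return False
    return major_version(installed) == 2


if __name__ == "__main__":
    print("typeguard 2.* installed:", typeguard_is_compatible())
```

Passing a version string explicitly, e.g. `typeguard_is_compatible("4.0.0")`, returns `False`, matching the mismatch in the environment listed above.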