asteroid-team / asteroid

The PyTorch-based audio source separation toolkit for researchers

Home Page: https://asteroid-team.github.io/

deep_clustering_loss example code is broken

geajack opened this issue

📚 Documentation

The example code for the affinity loss on this page seems to be broken:

import torch
from asteroid.losses.cluster import deep_clustering_loss
spk_cnt = 3
embedding = torch.randn(10, 5*400, 20)
targets = torch.LongTensor([10, 400, 5]).random_(0, spk_cnt)
loss = deep_clustering_loss(embedding, targets)

This results in:

Traceback (most recent call last):
  File "loss.py", line 6, in <module>
    loss = deep_clustering_loss(embedding, targets)
  File "/home/jack/.local/lib/python3.8/site-packages/asteroid/losses/cluster.py", line 38, in deep_clustering_loss
    batch, bins, frames = tgt_index.shape
ValueError: not enough values to unpack (expected 3, got 1)

PyTorch version: 1.10.2
Asteroid version: 0.5.2

Thanks for reporting this bug!

Would you like to fix it? If you look at the unit tests, you'll probably find a good example of how it's used 😉

Hi geajack. The problem is in the following line:

targets = torch.LongTensor([10, 400, 5]).random_(0, spk_cnt)

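A list of three elements is passed, so `torch.LongTensor` treats it as data and builds a 1-D tensor holding the values 10, 400 and 5 instead of a tensor of shape (10, 400, 5). That is why `tgt_index.shape` unpacks to only one value. To get a (batch, bins, frames) tensor, the sizes must be passed as three separate arguments, like this: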

targets = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
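A minimal sketch, using only PyTorch, of the difference between the two constructor calls. It reproduces the unpacking error from the traceback and also shows `torch.randint` as an alternative way to build the label tensor; `torch.randint` is a suggestion not mentioned in the thread, and the names `from_list` and `from_sizes` are purely illustrative.

```python
import torch

spk_cnt = 3

# A list argument is treated as the tensor's *data*:
# the result is a 1-D tensor containing the values 10, 400 and 5.
from_list = torch.LongTensor([10, 400, 5])
print(from_list.shape)  # torch.Size([3])

# Separate int arguments are treated as *sizes*: the result is an
# uninitialized (batch, bins, frames) tensor, which random_ fills
# in place with integers drawn from [0, spk_cnt).
from_sizes = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
print(from_sizes.shape)  # torch.Size([10, 400, 5])

# The same three-way unpacking as in the traceback fails on the 1-D shape.
try:
    batch, bins, frames = from_list.shape
except ValueError as err:
    print(err)  # not enough values to unpack (expected 3, got 1)

# Alternative: a factory function that takes the shape as a tuple
# and returns int64 labels in [0, spk_cnt).
targets = torch.randint(0, spk_cnt, (10, 400, 5))
print(targets.shape, targets.dtype)  # torch.Size([10, 400, 5]) torch.int64
```

Once `targets` has shape (10, 400, 5), the `batch, bins, frames = tgt_index.shape` line from the traceback unpacks cleanly; the `deep_clustering_loss(embedding, targets)` call itself stays as in the original snippet.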

I'll be submitting a PR for this issue, if you don't mind :)