HanxunH / SCELoss-Reproduce

Reproduced results for the ICCV 2019 paper "Symmetric Cross Entropy for Robust Learning with Noisy Labels": https://arxiv.org/abs/1908.06112



Error when predictions have extra dimensions other than classes

confifu opened this issue

torch.nn.CrossEntropyLoss accepts predictions of shape (minibatch, classes, d1, d2, ...). The following line raises an error when the extra dimensions d1, d2, ... are present:

rce = (-1*torch.sum(pred * torch.log(label_one_hot), dim=1))
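A minimal reproduction, assuming a segmentation-style output with hypothetical sizes N=2, C=3, H=4, W=5, a one-hot target built with torch.nn.functional.one_hot, and a 1e-4 clamp on that target (the thread does not show how label_one_hot is built). CrossEntropyLoss handles the layout, but the element-wise product in the RCE line does not:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: minibatch N, classes C, spatial dims H, W.
N, C, H, W = 2, 3, 4, 5
logits = torch.randn(N, C, H, W)          # (minibatch, classes, d1, d2)
labels = torch.randint(0, C, (N, H, W))   # (minibatch, d1, d2)

# CrossEntropyLoss accepts this layout as-is.
ce = torch.nn.CrossEntropyLoss()(logits, labels)

pred = F.softmax(logits, dim=1)                # (N, C, H, W)
# Assumption: the target comes from F.one_hot, which puts the class axis last,
# and is clamped to [1e-4, 1] so that log() stays finite.
label_one_hot = F.one_hot(labels, C).float()   # (N, H, W, C)
label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)

print(tuple(pred.shape), tuple(label_one_hot.shape))  # (2, 3, 4, 5) (2, 4, 5, 3)

raised = False
try:
    rce = -1 * torch.sum(pred * torch.log(label_one_hot), dim=1)
except RuntimeError as err:
    # The trailing dims (5 vs 3) cannot be broadcast together.
    raised = True
    print("RuntimeError:", err)
```

With other sizes (for example C == H == W) the two shapes happen to broadcast, and the product silently pairs the wrong entries instead of raising.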

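The predictions have shape (minibatch, classes, d1, d2, ...), but label_one_hot has shape (minibatch, d1, d2, ..., classes), which is what torch.nn.functional.one_hot produces, since it appends the class dimension last. I think the class dimension of label_one_hot should be moved to the second position before the multiplication. A plain transpose(1, -1) only gives the right order when there is a single extra dimension; with more, it also moves d1 to the end.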
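One way to apply the suggested fix, sketched as a standalone function. The name reverse_cross_entropy and the 1e-4 clamp on the one-hot labels are assumptions for illustration, not code from the repository. torch.movedim moves the class axis from last to second while keeping d1, ..., dk in order, and for plain (N, C) inputs it is a no-op, so the same code also covers the 2-D case:

```python
import torch
import torch.nn.functional as F


def reverse_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Per-element RCE term for logits (N, C, d1, ..., dk) and labels (N, d1, ..., dk).

    Hypothetical helper; the 1e-4 clamp is an assumed value that keeps log() finite.
    """
    pred = F.softmax(logits, dim=1)                          # (N, C, d1, ..., dk)
    label_one_hot = F.one_hot(labels, num_classes).float()   # (N, d1, ..., dk, C)
    label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)
    # Move classes from the last axis to axis 1, keeping d1..dk in order.
    # transpose(1, -1) would instead put d1 last, scrambling the order when k > 1.
    label_one_hot = torch.movedim(label_one_hot, -1, 1)      # (N, C, d1, ..., dk)
    return -torch.sum(pred * torch.log(label_one_hot), dim=1)  # (N, d1, ..., dk)


torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 5)
labels = torch.randint(0, 3, (2, 4, 5))
print(reverse_cross_entropy(logits, labels, num_classes=3).shape)  # torch.Size([2, 4, 5])

# The 2-D case (N, C) still works: movedim(-1, 1) leaves the tensor unchanged there.
print(reverse_cross_entropy(torch.randn(8, 10), torch.randint(0, 10, (8,)), 10).shape)  # torch.Size([8])
```

The per-element result can then be reduced with .mean() as in the 2-D case. Because the clamped one-hot has log(1) = 0 at the true class, each element equals -log(1e-4) * (1 - p_y), where p_y is the predicted probability of the true class.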

Hi,
Thanks for the suggestion! Feel free to open a PR.