Error when predictions have extra dimensions other than classes
confifu opened this issue · comments
torch.nn.CrossEntropyLoss takes predictions of shape (minibatch, classes, d1, d2, ...). This line leads to an error if the d1, d2, ... dimensions exist.
Line 23 in c81fb33
The predictions are of shape (minibatch, classes, d1, d2, ...), but the label_one_hot variable is of shape (minibatch, d1, d2, ..., classes). I think the label_one_hot tensor should be permuted so that the class dimension becomes the second dimension.
Hi,
Thanks for the suggestion! Feel free to open a PR.
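For reference, here is a minimal sketch of the fix suggested above. It is not the repository's code: the helper name `one_hot_channels_first`, the tensor shapes, and the hand-written loss are assumptions made only to illustrate moving the class dimension of a one-hot tensor to position 1 so that it matches the prediction layout.

```python
import torch
import torch.nn.functional as F


def one_hot_channels_first(labels, num_classes):
    """Hypothetical helper: one-hot encode integer labels of shape
    (minibatch, d1, d2, ...) into shape (minibatch, classes, d1, d2, ...)."""
    # F.one_hot appends the class dimension last: (minibatch, d1, ..., classes).
    one_hot = F.one_hot(labels, num_classes).float()
    last = one_hot.dim() - 1
    # Reorder to (minibatch, classes, d1, d2, ...).
    return one_hot.permute(0, last, *range(1, last))


num_classes = 5
# Assumed example shapes: a batch of 2 with extra dimensions d1=3, d2=4.
pred = torch.randn(2, num_classes, 3, 4)           # (minibatch, classes, d1, d2)
labels = torch.randint(0, num_classes, (2, 3, 4))  # (minibatch, d1, d2)

label_one_hot = one_hot_channels_first(labels, num_classes)

# With matching layouts, a hand-written cross entropy over dim=1 agrees
# with torch.nn.CrossEntropyLoss applied to the integer labels.
manual = -(label_one_hot * F.log_softmax(pred, dim=1)).sum(dim=1).mean()
reference = torch.nn.CrossEntropyLoss()(pred, labels)
```

When there are no extra dimensions, the permutation is the identity, so the same helper should also cover the plain (minibatch, classes) case.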