I think focal_loss.py is wrong
doiken23 opened this issue · comments
Kento Doi commented
Hi.
I referred to your code for my focal loss project, but I think there is a mistake in focal_loss.py.
```python
cross_entropy = F.cross_entropy(output, target)
cross_entropy_log = torch.log(cross_entropy)
```
should be
```python
logpt = -F.cross_entropy(input, target)
pt = torch.exp(logpt)
```
Please confirm it.
Best regards.
DonGovi commented
I agree with you. `F.cross_entropy` in PyTorch contains `log_softmax`, i.e. a softmax followed by a logarithm. So the `pt` in the Focal Loss paper should be `exp(-cross_entropy)`.
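A quick check of this point (a standalone sketch, not code from the repo): `F.cross_entropy` returns `-log_softmax(output)[target]`, so exponentiating its negative recovers the true-class probability `pt`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
output = torch.randn(4, 3)           # logits for 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 1])  # true class indices

ce = F.cross_entropy(output, target, reduction='none')  # per-sample -log(pt)
pt = torch.exp(-ce)                                     # probability of the true class

# The same quantity computed directly from the softmax output:
probs = F.softmax(output, dim=1)
pt_direct = probs.gather(1, target.unsqueeze(1)).squeeze(1)

assert torch.allclose(pt, pt_direct)
```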
pkwangwanjun commented
I agree with you too. In the paper, the focal loss is `-a*(1-pt)^b*log(pt)`. If `p` is the predicted probability of a sample and its true label is 1, then `pt` should be `p`; otherwise it is `1-p`.
so
```python
logpt = -F.cross_entropy(input, target)
pt = torch.exp(logpt)
```
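Putting the suggested fix together, a minimal corrected focal loss could look like the sketch below. This is an illustration of the formula discussed above, not the repo's actual focal_loss.py; the parameter names `alpha` and `gamma` follow the paper's notation.

```python
import torch
import torch.nn.functional as F

def focal_loss(output, target, alpha=0.25, gamma=2.0):
    # F.cross_entropy already includes log_softmax, so its negative is log(pt)
    logpt = -F.cross_entropy(output, target, reduction='none')
    pt = torch.exp(logpt)                       # probability of the true class
    loss = -alpha * (1.0 - pt) ** gamma * logpt # -a * (1-pt)^g * log(pt)
    return loss.mean()

torch.manual_seed(0)
output = torch.randn(8, 5)              # logits for 8 samples, 5 classes
target = torch.randint(0, 5, (8,))
print(focal_loss(output, target))
```

With `alpha=1` and `gamma=0` this reduces to the ordinary mean cross entropy, which is a handy sanity check.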