c0nn3r / RetinaNet

An implementation of RetinaNet in PyTorch.

I think focal_loss.py is wrong

doiken23 opened this issue · comments

Hi.
I referred to your code for my focal loss project, but I think there is a mistake in focal_loss.py.

cross_entropy = F.cross_entropy(output, target)
cross_entropy_log = torch.log(cross_entropy)

should be

logpt = -F.cross_entropy(input, target)
pt = torch.exp(logpt) 

Please confirm it.
Best regards.

I agree with you. PyTorch's "cross_entropy" applies "log_softmax" (a softmax followed by a logarithm) and then takes the negative log-likelihood of the target class, so it already returns -log(pt). The "pt" in the Focal Loss paper should therefore be exp(-cross_entropy).
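
For anyone who wants to check this numerically, here is a minimal sketch. The logits and targets are made up for illustration and do not come from the repository. It assumes `reduction="none"`, so the loss is kept per sample. With the default `reduction="mean"`, `F.cross_entropy` averages over the batch first, and the exponential of that mean is not a per-sample pt.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)           # hypothetical batch: 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])  # hypothetical true class indices

# Per-sample cross entropy, which equals -log(pt) for each sample
ce = F.cross_entropy(logits, target, reduction="none")
pt_from_ce = torch.exp(-ce)

# pt computed directly: softmax probability assigned to the true class
pt_direct = F.softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)

# The two should agree up to floating-point error
identity_holds = torch.allclose(pt_from_ce, pt_direct, atol=1e-6)
print(identity_holds)
```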

I agree with you too. In the paper, the focal loss is -a*(1-pt)^b*log(pt). If p is the predicted probability for a sample and its true label is 1, then pt is p; otherwise pt is 1-p.

so
logpt = -F.cross_entropy(input, target)
pt = torch.exp(logpt)
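
Putting the two lines together with the formula -a*(1-pt)^b*log(pt), a multi-class focal loss could be sketched as follows. This is not the repository's code: the function name `focal_loss`, the scalar `alpha`, and the default values are assumptions for illustration, and the batch mean at the end is one possible reduction. A single scalar `alpha` is also a simplification of per-class weighting.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    """Sketch of -alpha * (1 - pt)**gamma * log(pt), averaged over the batch."""
    # reduction="none" keeps one value per sample, so pt is per sample too
    logpt = -F.cross_entropy(logits, target, reduction="none")
    pt = torch.exp(logpt)
    loss = -alpha * (1.0 - pt) ** gamma * logpt
    return loss.mean()

# Hypothetical inputs: 4 samples, 3 classes
torch.manual_seed(0)
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])

fl = focal_loss(logits, target)
# With alpha=1 and gamma=0 the modulating factor disappears,
# so the result should match plain cross entropy.
fl_as_ce = focal_loss(logits, target, alpha=1.0, gamma=0.0)
ce = F.cross_entropy(logits, target)
```

Setting `alpha=1.0` and `gamma=0.0` is a handy sanity check, since the expression then reduces to ordinary cross entropy.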