Your focal loss is wrong? It seems a little different from other implementations; can you explain your code?
Jack-zz-ze opened this issue
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits  # True if inputs are raw scores rather than probabilities
        self.reduce = reduce

    def forward(self, inputs, targets):
        # per-element binary cross-entropy; reduce=False is the deprecated
        # spelling of reduction='none', keeping the per-element losses
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)  # pt = model's probability of the true class
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
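For context, a minimal usage sketch of the module above (the shapes and values here are illustrative, not from the repo):

criterion = FocalLoss(alpha=1, gamma=2, logits=True)
logits = torch.randn(8)                      # raw model scores
targets = torch.randint(0, 2, (8,)).float()  # binary labels
loss = criterion(logits, targets)            # scalar, since reduce=True
print(loss.item())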
They just use a forward function, while you use both forward and backward; can you explain that?
Have you compared the outputs of the two implementations?
If you find my focal loss is wrong, please post example code showing the difference from the correct implementation, and I will see where the problem is in my code.
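For reference, a sketch of such a comparison (the second version below is reconstructed from the log_probs snippet quoted further down, not taken verbatim from the repo, and the alpha/gamma values are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(16) * 10                    # logits, including large magnitudes
y = torch.randint(0, 2, (16,)).float()      # binary labels
alpha, gamma = 1.0, 2.0

# Version A: BCE-based focal loss, as in the class above
bce = F.binary_cross_entropy_with_logits(x, y, reduction='none')
pt = torch.exp(-bce)
loss_a = alpha * (1 - pt) ** gamma * bce

# Version B: built from the stable log-sigmoid (softplus) trick
log_p = torch.where(x >= 0, F.softplus(x, -1, 50), x - F.softplus(x, 1, 50))        # log(sigmoid(x))
log_1mp = torch.where(-x >= 0, F.softplus(-x, -1, 50), -x - F.softplus(-x, 1, 50))  # log(1 - sigmoid(x))
log_pt = y * log_p + (1 - y) * log_1mp      # equals -bce
loss_b = -alpha * (1 - torch.exp(log_pt)) ** gamma * log_pt

print(torch.allclose(loss_a, loss_b, atol=1e-6))  # expect True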
1. focal_loss = -α * (1 - pt)^γ * log(pt)
2. your code:
log_probs = torch.where(logits >= 0,
                        F.softplus(logits, -1, 50),
                        logits - F.softplus(logits, 1, 50))
3. Do you want to compute log(pt)? But F.softplus is (1/β) * log(1 + e^(βx)).
log(pt) = label * log(p) + (1 - label) * log(1 - p),
but log_probs = -log(1 + e^(-x)) for x >= 0, or x - log(1 + e^(x)) for x < 0, which is different.
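A quick numeric check (illustrative; F.logsigmoid is used here only as a reference) shows that the torch.where expression is simply a numerically stable log(sigmoid(x)), with no label weighting inside it:

import torch
import torch.nn.functional as F

x = torch.linspace(-100.0, 100.0, steps=201)     # includes extreme logits
log_probs = torch.where(x >= 0,
                        F.softplus(x, -1, 50),    # -log(1 + e^(-x))   for x >= 0
                        x - F.softplus(x, 1, 50)) # x - log(1 + e^(x)) for x < 0
print(torch.allclose(log_probs, F.logsigmoid(x))) # True on both branches
# a naive torch.log(torch.sigmoid(x)) can underflow to -inf for very negative x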
I have the same question; could you please explain this line?
log_probs = torch.where(logits >= 0,
                        F.softplus(logits, -1, 50),
                        logits - F.softplus(logits, 1, 50))
Just work through the formula derivations: these implementations are exactly the same, and this version is more numerically stable. (Awesome work!) @noobgrow @Jack-zz-ze
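For completeness, a sketch of that derivation, with p = sigmoid(x) = 1/(1 + e^(-x)):

for x >= 0: F.softplus(x, -1, 50) = -log(1 + e^(-x)) = log(1/(1 + e^(-x))) = log(p)
for x < 0: x - F.softplus(x, 1, 50) = x - log(1 + e^(x)) = log(e^(x)/(1 + e^(x))) = log(p)

So both branches compute log(sigmoid(x)); the split only avoids overflowing e^(-x) or e^(x). Substituting log(p) and log(1 - p) = log(sigmoid(-x)) into log(pt) = label * log(p) + (1 - label) * log(1 - p) gives BCE_loss = -log(pt), hence pt = exp(-BCE_loss), and both implementations evaluate -α * (1 - pt)^γ * log(pt).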
Closing this because it is not active anymore.