CoinCheung / pytorch-loss

label-smooth, amsoftmax, partial-fc, focal-loss, triplet-loss, lovasz-softmax. Maybe useful


Is your focal loss wrong? It seems a little different from the others; can you explain your code?

Jack-zz-ze opened this issue · comments

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # element-wise BCE, kept unreduced so the focal weight can be applied per element
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # probability assigned to the true class
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
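
A minimal usage sketch for the class above (shapes and values are illustrative only, and it relies on the imports in the snippet; with logits=True the inputs are raw scores and the targets are 0/1 floats):

# illustrative call: raw logits against binary targets of the same shape
logits = torch.randn(4, 10, requires_grad=True)
targets = torch.randint(0, 2, (4, 10)).float()

criterion = FocalLoss(alpha=1, gamma=2, logits=True, reduce=True)
loss = criterion(logits, targets)
loss.backward()  # behaves like any other nn.Module loss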

They just use a forward function, while you use both forward and backward. Can you explain that?

Have you compared the outputs of the two implementations?

If you think my focal loss is wrong, please post example code that shows the difference from a correct implementation, and I will see where the problem is in my code.
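
For anyone following along, here is a minimal self-contained sketch of such a comparison (an editor's sketch, not code taken from either implementation verbatim; alpha is assumed to be 1 so both formulations weight positives and negatives identically):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8)
targets = torch.randint(0, 2, (8,)).float()
gamma = 2.0

# formulation A (the snippet above): focal weight built from pt = exp(-BCE)
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
pt = torch.exp(-bce)                        # probability of the true class
loss_a = ((1 - pt) ** gamma * bce).mean()

# formulation B: the same loss written with explicit, numerically stable log-probabilities
log_p  = torch.where(logits >= 0, F.softplus(logits, -1, 50), logits - F.softplus(logits, 1, 50))    # log(sigmoid(x))
log_1p = torch.where(logits >= 0, -logits + F.softplus(logits, -1, 50), -F.softplus(logits, 1, 50))  # log(1 - sigmoid(x))
p = torch.sigmoid(logits)
focal_weight = (targets - p).abs() ** gamma  # equals (1 - pt) ** gamma for binary targets
loss_b = -(focal_weight * (targets * log_p + (1 - targets) * log_1p)).mean()

print(loss_a.item(), loss_b.item())  # the two values should agree up to floating-point error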

1. focal_loss = -α * (1-pt)**γ * log(pt)
2. your code:
log_probs = torch.where(logits >= 0,
                        F.softplus(logits, -1, 50),
                        logits - F.softplus(logits, 1, 50))
3. Do you mean to compute log(pt)? But F.softplus(x, β) is (1/β) * log(1 + e^(βx)).
log(pt) = label * log(p) + (1-label) * log(1-p),
but log_probs = -x * log(1 + e^(-x)) + (1-x) * [x - log(1 + e^(x))], which is different.
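
For readers stuck at the same point, a short derivation (added here for clarity; σ denotes the sigmoid) shows that both branches of that torch.where are the same quantity, log σ(x), i.e. log(p) for the positive class:

\[
\operatorname{softplus}(x,\beta) = \frac{1}{\beta}\log\bigl(1+e^{\beta x}\bigr)
\]
\[
\beta=-1:\quad \operatorname{softplus}(x,-1) = -\log\bigl(1+e^{-x}\bigr) = \log\frac{1}{1+e^{-x}} = \log\sigma(x)
\]
\[
\beta=1:\quad x - \operatorname{softplus}(x,1) = x - \log\bigl(1+e^{x}\bigr) = \log\frac{e^{x}}{1+e^{x}} = \log\sigma(x)
\]

So torch.where does not form a weighted sum of the two expressions; it simply selects whichever algebraic form of log σ(x) avoids exponentiating a large |x| for the current sign of the logit. The (1-label) * log(1-p) term can be handled with an analogous expression.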

commented

I have the same question. Could you please explain this line?
log_probs = torch.where(logits >= 0,
                        F.softplus(logits, -1, 50),
                        logits - F.softplus(logits, 1, 50))

commented

Just work through the formula derivations: the two implementations are exactly the same, and this version is more numerically stable. (Awesome work!) @noobgrow @Jack-zz-ze
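
To illustrate the stability claim, here is a small self-contained check (an editor's sketch, not code from the repository) comparing the naive log(sigmoid(x)) against the softplus form quoted above at extreme logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([-150.0, -20.0, 0.0, 20.0, 150.0])

# naive form: sigmoid underflows to 0 for very negative logits, so the log becomes -inf
naive = torch.log(torch.sigmoid(logits))

# stable form: both branches equal log(sigmoid(x)) but never exponentiate a large |x|
stable = torch.where(logits >= 0,
                     F.softplus(logits, -1, 50),
                     logits - F.softplus(logits, 1, 50))

print(naive)   # the -150 entry has underflowed to -inf
print(stable)  # all entries stay finite, e.g. -150., -20., -0.6931, ...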

Closing this because it is not active anymore.