CoinCheung / pytorch-loss

label-smooth, amsoftmax, partial-fc, focal-loss, triplet-loss, lovasz-softmax. Maybe useful

Calculation about focal loss

BinsonW opened this issue

Thanks for your great work!

As far as I know, the focal loss is $ FL = -y(1-p)^\gamma \log(p) - (1-y)p^\gamma \log(1-p) $, but the code in focal_loss.py is:

        coeff = torch.abs(label - probs).pow(self.gamma).neg()
        log_probs = torch.where(logits >= 0,
                F.softplus(logits, -1, 50),
                logits - F.softplus(logits, 1, 50))
        log_1_probs = torch.where(logits >= 0,
                -logits + F.softplus(logits, -1, 50),
                -F.softplus(logits, 1, 50))
        loss = label * self.alpha * log_probs + (1. - label) * (1. - self.alpha) * log_1_probs
        loss = loss * coeff

It seems that $ \log(p) $ in the focal loss is implemented with F.softplus(). But F.softplus() computes $ y=\log(1+e^x) $, which is not $ \log(x) $.

How should I understand this difference? Or did I miss something necessary about focal loss?

That is a mathematical rearrangement that makes expressions such as log(1 + exp(x)) more numerically stable.
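
To spell the rearrangement out, here is a short derivation. It assumes that `probs` in the quoted snippet is `sigmoid(logits)` (that line is not shown in the issue) and that `F.softplus(x, beta, threshold)` computes $ \frac{1}{\beta}\log(1+e^{\beta x}) $. Writing $ x $ for the logit:

$$
p = \sigma(x) = \frac{1}{1+e^{-x}}, \qquad \mathrm{softplus}_\beta(x) = \frac{1}{\beta}\log\left(1+e^{\beta x}\right)
$$

$$
\log p = -\log\left(1+e^{-x}\right) = \mathrm{softplus}_{-1}(x) = x - \mathrm{softplus}_{1}(x)
$$

$$
\log(1-p) = -\log\left(1+e^{x}\right) = -\mathrm{softplus}_{1}(x) = -x + \mathrm{softplus}_{-1}(x)
$$

Each quantity has two equivalent forms. The `torch.where(logits >= 0, ...)` branches pick the form whose exponential has a non-positive argument ($ e^{-x} $ when $ x \ge 0 $, $ e^{x} $ when $ x < 0 $), so the exponential cannot overflow. The coefficient $ |y-p|^\gamma $ equals $ (1-p)^\gamma $ for $ y=1 $ and $ p^\gamma $ for $ y=0 $, so apart from the `alpha` weights the code appears to match the formula in the question term by term.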

Also, please do not overlook the sigmoid function. It is included in the loss function, in the same way as in nn.BCEWithLogitsLoss.
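
As a rough numerical check of the point about stability, the sketch below re-implements the branching from the quoted snippet for a single scalar logit. It uses only Python's `math` module rather than PyTorch so that it runs without dependencies. The function names, and the defaults `alpha=0.25` and `gamma=2.0`, are illustrative assumptions and not the repository's API.

```python
import math

def softplus(x, beta=1.0):
    # Plain (1/beta) * log(1 + exp(beta * x)).
    # Only safe when beta * x is not large and positive.
    return math.log1p(math.exp(beta * x)) / beta

def log_sigmoid(x):
    # log(p) with p = sigmoid(x), branching as in the quoted snippet
    if x >= 0:
        return softplus(x, beta=-1.0)       # -log(1 + e^{-x})
    return x - softplus(x, beta=1.0)        # x - log(1 + e^{x})

def log_one_minus_sigmoid(x):
    # log(1 - p), again choosing the branch with a non-positive exponent
    if x >= 0:
        return -x + softplus(x, beta=-1.0)  # -x - log(1 + e^{-x})
    return -softplus(x, beta=1.0)           # -log(1 + e^{x})

def naive_log_sigmoid(x):
    # Direct translation of log(1 / (1 + e^{-x})); overflows for very negative x
    return math.log(1.0 / (1.0 + math.exp(-x)))

def focal_loss(logit, label, alpha=0.25, gamma=2.0):
    # Scalar version of the quoted computation (alpha and gamma are assumed defaults)
    log_p = log_sigmoid(logit)
    log_1_p = log_one_minus_sigmoid(logit)
    p = math.exp(log_p)                     # sigmoid(logit), computed without overflow
    coeff = -abs(label - p) ** gamma        # -(1 - p)^gamma if label == 1, -p^gamma if label == 0
    return (label * alpha * log_p + (1.0 - label) * (1.0 - alpha) * log_1_p) * coeff

if __name__ == "__main__":
    print(log_sigmoid(2.0), naive_log_sigmoid(2.0))  # the two agree for moderate logits
    print(log_sigmoid(-1000.0))                      # finite, where the naive form overflows
    print(focal_loss(0.0, 1.0))
```

For moderate logits the stable and naive forms agree. For a logit such as -1000 the naive form raises `OverflowError` inside `math.exp`, while the branched form stays finite.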