nicolasrosa / Loss-Functions-For-Semantic-Segmentation

My own implementations of some loss functions that have been used for segmentation tasks.


Losses Used in Segmentation Tasks

Image segmentation can be defined as a classification task at the pixel level. An image consists of many pixels, and these pixels, grouped together, define different elements in the image. A method of classifying these pixels into elements is called semantic segmentation.

The choice of loss/objective function is extremely important. In the paper, the authors summarize 15 segmentation-based loss functions that have been shown to provide state-of-the-art results in different domains.

Table of loss functions, grouped by type:

  • Distribution-based Loss: Binary Cross-Entropy, Weighted Cross-Entropy, Balanced Cross-Entropy, Focal Loss, Distance map derived loss penalty term
  • Region-based Loss: Dice Loss, Sensitivity-Specificity Loss, Tversky Loss, Focal Tversky Loss, Log-Cosh Dice Loss
  • Boundary-based Loss: Hausdorff Distance Loss, Shape aware loss
  • Compounded Loss: Combo Loss, Exponential Logarithmic Loss

An optimizer is used to learn the objective. To learn an objective accurately and quickly, we need to ensure that its mathematical representation (the loss function) covers even the edge cases.

In the paper, the authors focus on semantic segmentation rather than instance segmentation, so the formulas are presented with the number of classes at the pixel level restricted to 2 (the binary case).

Binary Cross-Entropy

Cross-entropy is defined as a measure of the difference between two probability distributions for a given random variable or set of events.

Usage: It is used as a classification objective, and since segmentation is pixel-level classification, it works well.

Binary Cross-Entropy (BCE) is defined as:

$L_{BCE}(y, \hat{y}) = -\big(y \log(\hat{y}) + (1 - y)\log(1 - \hat{y})\big)$

where $y$ is the ground-truth label and $\hat{y}$ is the predicted probability.

In this case, we have just 2 classes. With more classes, the formula becomes a sum over more terms, and the values inside the log are the result of a softmax (applied to a vector of scores) instead of a sigmoid (applied to a scalar).

PyTorch provides BCELoss as a built-in loss function. Read more at: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss

Note: PyTorch uses base e (the natural logarithm) for the log function.
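
As a quick sanity check (a minimal sketch, not from the paper), the following shows that nn.BCELoss matches the formula above when the natural logarithm is used:

import torch
import torch.nn as nn

y_hat = torch.tensor([0.8, 0.3, 0.6])   # predicted probabilities (after sigmoid)
y = torch.tensor([1.0, 0.0, 1.0])       # ground-truth labels

# Manual BCE with natural log, averaged over the minibatch
manual = -(y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat)).mean()
builtin = nn.BCELoss()(y_hat, y)
print(torch.allclose(manual, builtin))  # True

# Note: for raw logits (no sigmoid applied), nn.BCEWithLogitsLoss is the
# numerically stable choice.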

Multi-class case:

$CE = -\frac{1}{N}\sum_{c=1}^{N}\sum_{l=1}^{L} r^c_l \log(p^c_l)$

where:

N: number of pixels to classify in a minibatch

c: index of a pixel

l: index of a label; L is the number of classes to classify

$p^c$: predicted probability vector for pixel c, given by the output of the model (usually obtained by applying softmax to the model output); $p^c_l$ is its l-th component

$r^c$: one-hot encoded ground-truth vector for pixel c, where the entry of the true class is 1 and the other entries are 0

When the model is used to predict, the relation between the predicted label $l^c$ and $p^c$ is $l^c = \arg\max_{l} p^c_l$.
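
To make the notation concrete, here is a small sketch (not from the paper) showing that PyTorch's nn.CrossEntropyLoss applies softmax to the raw model outputs and averages the negative log-probabilities of the true classes, and that the predicted label is the argmax of $p^c$:

import torch
import torch.nn as nn
import torch.nn.functional as F

N, L = 4, 3                                   # 4 pixels, 3 classes
logits = torch.randn(N, L)                    # raw model outputs (before softmax)
labels = torch.randint(0, L, (N,))            # ground-truth class index per pixel

p = F.softmax(logits, dim=1)                  # p^c: predicted probability vectors
pred_labels = p.argmax(dim=1)                 # l^c: predicted label per pixel

# Manual multi-class cross-entropy: -1/N * sum over pixels of log(p^c at the true class)
manual_ce = -torch.log(p[torch.arange(N), labels]).mean()
builtin_ce = nn.CrossEntropyLoss()(logits, labels)
print(torch.allclose(manual_ce, builtin_ce))  # True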

OK, let's move on to the next loss.

Weighted Binary Cross-Entropy (WCE)

It is a variant of binary cross-entropy. It is widely used in the case of skewed data, where the number of instances in each class is imbalanced.

Multi-class case:
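
One common way to write the weighted multi-class cross-entropy with the notation used above (an assumed reconstruction of the missing formula; here $w_l$ is proportional to the frequency of class $l$, so $1/w_l$ up-weights rare classes):

$WCE = -\frac{1}{N}\sum_{c=1}^{N}\sum_{l=1}^{L} \frac{1}{w_l}\, r^c_l \log(p^c_l)$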

The tendency to under-predict underrepresented classes can be mitigated by assigning higher weights to the loss contributions from pixels whose class labels have fewer instances.

The class_weight computed by sklearn (the 'balanced' mode uses n_samples / (n_classes * np.bincount(y))) is equivalent to the term $1/w_l$ in the equation above.

One way to obtain the weights is to compute them from the ground-truth labels, for example:

# The minibatch has 20 pixels and there are 5 classes. In PyTorch the target is
# a tensor of class indices rather than a one-hot tensor.
import numpy as np
import torch
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight

y = torch.randint(0, 5, (20,))
weights = compute_class_weight('balanced', classes=np.unique(y.numpy()), y=y.numpy())
weights = torch.tensor(weights, dtype=torch.float)
# Then pass these weights to CrossEntropyLoss, for example:

loss_fn = nn.CrossEntropyLoss(reduction='mean')
# For each minibatch: compute the class weights with the code above,
# then update the loss weights before computing the loss
# (output, target and optimizer come from the surrounding training loop)
loss_fn.weight = weights
loss = loss_fn(output, target)
loss.backward()
optimizer.step()

There are also other ways to obtain a weight map, such as the one introduced in the U-Net paper: https://arxiv.org/abs/1505.04597

Note: (I don't quite understand the note inside the paper.)

Balanced Cross-Entropy

It is similar to Weighted Cross-Entropy. The only difference is that a weight is also applied to the negative examples.
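
A standard formulation (the exact notation may differ from the paper) weights the positive term by $\beta$ and the negative term by $1 - \beta$:

$L_{BalancedCE}(y, \hat{y}) = -\big(\beta\, y \log(\hat{y}) + (1 - \beta)(1 - y)\log(1 - \hat{y})\big)$

where $\beta$ is typically set to the proportion of negative pixels in the image.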

Focal Loss (Implemented)

Binary Classification Case:

This can also be seen as a variation of Binary Cross-Entropy. It down-weights the contribution of easy examples and enables the model to focus more on learning hard examples.

Focal Loss down-weights easy examples and focuses training on hard negatives using a modulating factor:

$FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$

where $p_t$ is the model's estimated probability for the true class.

Here $\gamma \ge 0$, and when $\gamma = 0$ (with $\alpha_t = 1$) Focal Loss reduces to the Cross-Entropy loss. The parameter $\alpha$ lies in the range [0, 1]; it can be set by the inverse class frequency or treated as a hyperparameter.
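
The loss is implemented in this repository; the snippet below is only a minimal sketch of binary focal loss built on top of BCE (the function name and default values are illustrative, not necessarily those of the repository's implementation):

import torch
import torch.nn.functional as F

def binary_focal_loss(y_hat, y, alpha=0.25, gamma=2.0):
    # y_hat: predicted probabilities in (0, 1); y: binary ground truth
    bce = F.binary_cross_entropy(y_hat, y, reduction='none')  # -log(p_t)
    p_t = y * y_hat + (1 - y) * (1 - y_hat)                   # probability of the true class
    alpha_t = y * alpha + (1 - y) * (1 - alpha)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

y_hat = torch.tensor([0.9, 0.2, 0.7])
y = torch.tensor([1.0, 0.0, 0.0])
print(binary_focal_loss(y_hat, y))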

Multi-class Classification Case:
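
A common multi-class generalization, using the same notation as in the cross-entropy section (this is an assumed reconstruction of the missing formula, not necessarily the exact form from the paper):

$FL = -\frac{1}{N}\sum_{c=1}^{N}\sum_{l=1}^{L} \alpha_l\,(1 - p^c_l)^{\gamma}\, r^c_l \log(p^c_l)$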


Dice Loss (Implemented)

The Dice coefficient is a widely used metric in computer vision to calculate the similarity between two images. In 2016, it was also adapted as a loss function, known as the Dice Loss.

In set-theory terms, the Dice coefficient measures the overlap between two sets $A$ and $B$:

$DSC = \frac{2\,|A \cap B|}{|A| + |B|}$

Binary classification:

$DL(y, \hat{y}) = 1 - \frac{2\, y \hat{y} + 1}{y + \hat{y} + 1}$

Here 1 is added to the numerator and denominator to ensure that the function is not undefined in edge-case scenarios, such as when $y = \hat{y} = 0$.

Multi-class task:

This loss was introduced in V-Net (2016) as the Soft Dice Loss. It is used to tackle class imbalance without the need for explicit weighting (as used in Weighted Cross-Entropy). One possible formulation is sketched below.
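
A minimal sketch of a soft Dice loss in this spirit, together with a batch variant (see the next item); the actual implementation in this repository may differ, e.g. in how the smoothing term and the batch dimension are handled:

import torch

def soft_dice_loss(probs, targets, smooth=1.0):
    # probs: (N, C, H, W) predicted probabilities (e.g. after softmax)
    # targets: (N, C, H, W) one-hot encoded ground truth
    dims = (2, 3)                                    # sum over spatial dims (per sample, per class)
    intersection = (probs * targets).sum(dims)
    cardinality = probs.sum(dims) + targets.sum(dims)
    dice = (2.0 * intersection + smooth) / (cardinality + smooth)
    return 1.0 - dice.mean()

def batch_soft_dice_loss(probs, targets, smooth=1.0):
    # "Batch" variant: also sum over the batch dimension, treating the whole
    # minibatch as a single volume before computing the Dice score per class
    dims = (0, 2, 3)
    intersection = (probs * targets).sum(dims)
    cardinality = probs.sum(dims) + targets.sum(dims)
    dice = (2.0 * intersection + smooth) / (cardinality + smooth)
    return 1.0 - dice.mean()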

Batch Soft Dice (a variant of Soft Dice) (Implemented, but not sure)

Tversky Loss (Implemented)
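
For reference (a standard formulation; the implementation in this repository may differ), the Tversky index generalizes the Dice coefficient by weighting false negatives and false positives separately, and the loss is one minus the index:

$TI = \frac{TP}{TP + \alpha\,FN + \beta\,FP}, \qquad L_{Tversky} = 1 - TI$

With $\alpha = \beta = 0.5$ this reduces to the Dice coefficient.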

Focal Tversky Loss (Implemented)

Sensitivity-Specificity Loss (Implemented)

Log-Cosh Dice Loss (Implemented)
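
As I understand it from the survey, this loss simply wraps the Dice loss in the smooth, tractable log-cosh function:

$L_{lc\text{-}dice} = \log\big(\cosh(L_{Dice})\big)$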

Hausdorff Distance Loss (Need time to read more papers)

References:

Github: https://github.com/HaipengXiong/weighted-hausdorff-loss, Paper: https://arxiv.org/pdf/1806.07564.pdf

Blob loss

References:

Github: https://github.com/neuronflow/blob_loss, Paper: https://arxiv.org/abs/2205.08209

Shape aware loss


Combo Loss (Implemented)

Exponential Logarithmic Loss (Implemented)

Robust T-Loss

This loss is used when dealing with datasets that have noisy annotations. It is inspired by the negative log-likelihood of the Student-t distribution.

References:

TODO

  • Crop small image chunks for testing the loss functions; I especially need to be sure about the Hausdorff loss.
  • For the next version, based on the Kornia library (https://github.com/kornia/kornia), implement stable versions that can be applied to higher-dimensional tensors, similar to how the loss functions in PyTorch work.
  • Read the papers about the remaining loss functions and implement them.
  • Make a table to easily compare these loss functions and when to use each of them.
  • Use some of these functions in a training process with the current model and see how they improve the prediction performance.


