msight-tech / research-ms-loss

MS-Loss: Multi-Similarity Loss for Deep Metric Learning

PyTorch implementation of MS-Loss
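For context, the loss in question is the Multi-Similarity loss of Wang et al. (CVPR 2019). Per-anchor, it pools mined positive pairs P_i and negative pairs N_i of cosine similarities S_ik against a threshold λ, with positive/negative weighting scales α and β:

L_MS = (1/m) * Σ_i [ (1/α) log(1 + Σ_{k∈P_i} exp(-α (S_ik − λ))) + (1/β) log(1 + Σ_{k∈N_i} exp(β (S_ik − λ))) ]

The code below is a direct translation of this formula, with the pair-mining step (keep positives below max_an + ε and negatives above min_ap − ε) applied before the two log-sum-exp terms.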

iGuaZi opened this issue · comments

import torch
import torch.nn as nn


class MultiSimilarityLoss(nn.Module):
    def __init__(self, configer=None):
        super(MultiSimilarityLoss, self).__init__()
        self.is_norm = True  # L2-normalise embeddings so the dot product is cosine similarity
        self.eps = 0.1       # mining margin epsilon
        self.lamb = 1        # similarity threshold lambda
        self.alpha = 2       # scale for positive pairs
        self.beta = 50       # scale for negative pairs

    def forward(self, inputs, targets):
        n = inputs.size(0)
        if self.is_norm:
            inputs = inputs / torch.norm(inputs, dim=1, keepdim=True)
        sim_matrix = inputs.matmul(inputs.t())
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())

        loss = None
        for i in range(n):
            temp_sim, temp_mask = sim_matrix[i], mask[i]
            pos_mask = temp_mask.clone()
            pos_mask[i] = False  # exclude the anchor's similarity with itself
            if not pos_mask.any():
                continue  # anchor has no positive pair in this batch
            min_ap, max_an = temp_sim[pos_mask].min(), temp_sim[~temp_mask].max()
            # pair mining; an empty selection is harmless since torch.sum(tensor([])) == tensor(0.)
            temp_AP = temp_sim[pos_mask & (temp_sim < max_an + self.eps)]
            temp_AN = temp_sim[~temp_mask & (temp_sim > min_ap - self.eps)]
            L1 = torch.log(1 + torch.sum(torch.exp(-self.alpha * (temp_AP - self.lamb)))) / self.alpha
            L2 = torch.log(1 + torch.sum(torch.exp(self.beta * (temp_AN - self.lamb)))) / self.beta
            L = L1 + L2
            loss = L if loss is None else loss + L
        loss /= n

        return loss
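For a quick sanity check, here is a self-contained functional sketch of the same computation (the function name `ms_loss` and the hyperparameter defaults are mine, chosen to match the module above) run on a small synthetic batch:

```python
import torch


def ms_loss(embeddings, labels, eps=0.1, thresh=1.0, alpha=2.0, beta=50.0):
    """Multi-Similarity loss over a batch, mirroring the module above."""
    n = embeddings.size(0)
    embeddings = torch.nn.functional.normalize(embeddings, dim=1)
    sim = embeddings @ embeddings.t()                 # cosine similarity matrix
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    terms = []
    for i in range(n):
        pos = same[i].clone()
        pos[i] = False                                # drop the anchor itself
        neg = ~same[i]
        if not pos.any() or not neg.any():
            continue                                  # nothing to mine for this anchor
        min_ap, max_an = sim[i][pos].min(), sim[i][neg].max()
        # pair mining: keep only informative positives and negatives
        ap = sim[i][pos & (sim[i] < max_an + eps)]
        an = sim[i][neg & (sim[i] > min_ap - eps)]
        l_pos = torch.log(1 + torch.exp(-alpha * (ap - thresh)).sum()) / alpha
        l_neg = torch.log(1 + torch.exp(beta * (an - thresh)).sum()) / beta
        terms.append(l_pos + l_neg)
    return torch.stack(terms).mean() if terms else embeddings.new_zeros(())


torch.manual_seed(0)
emb = torch.randn(16, 8)             # 16 embeddings of dimension 8
lab = torch.arange(4).repeat(4)      # 4 classes, 4 samples each
print(ms_loss(emb, lab).item())      # a finite, non-negative scalar
```

Both log-sum-exp terms are non-negative, so the loss is always ≥ 0, and empty mined sets contribute log(1 + 0) = 0.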

Thanks @iGuaZi for posting your code. Our official implementation is now available in this repo; you can compare yours against it to verify.