huawei-noah / Pretrained-Language-Model

Pretrained language model and its related optimization techniques developed by Huawei Noah's Ark Lab.

soft_cross_entropy in TinyBERT needs modification

sbwww opened this issue

The soft_cross_entropy loss function in TinyBERT, DynaBERT, and other distilled models seems inaccurate. The paper defines it as $CE(z^T/t, z^S/t)$, but the code does not compute what CE actually computes.

import torch

def soft_cross_entropy(predicts, targets):
    # student log-probabilities and teacher probabilities
    student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    # mean() averages over every element, i.e. over batch AND label dimensions
    return (- targets_prob * student_likelihood).mean()

The mean() here simply averages ALL the values in the tensor, but what we want is to sum the cross entropy over the label dimension for each sample and then average across the batch. So the soft_cross_entropy in the code is actually scaled down by a factor of num_labels!!
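
A quick numeric check illustrates the factor (a minimal sketch; the batch size, num_labels = 3, and random logits are just illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_labels = 4, 3
student_logits = torch.randn(batch_size, num_labels)
teacher_logits = torch.randn(batch_size, num_labels)

log_p_student = F.log_softmax(student_logits, dim=-1)
p_teacher = F.softmax(teacher_logits, dim=-1)

# current code: a single mean over batch_size * num_labels elements
loss_current = (-p_teacher * log_p_student).mean()

# proper cross entropy: sum over labels per sample, then average over the batch
per_sample_ce = torch.stack(
    [-(p_teacher[i] * log_p_student[i]).sum() for i in range(batch_size)]
)
loss_expected = per_sample_ce.mean()

print(loss_expected / loss_current)  # ~3.0, i.e. num_labels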

I think it is better to use the code below as an alternative.

return (- targets_prob * student_likelihood).sum(dim=-1).mean()
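
As a sanity check (a minimal sketch; the function name soft_cross_entropy_fixed is just for illustration, and it assumes PyTorch >= 1.10, where F.cross_entropy accepts probability targets), the summed-then-averaged version matches PyTorch's built-in cross entropy on soft targets:

import torch
import torch.nn.functional as F

def soft_cross_entropy_fixed(predicts, targets):
    # hypothetical name for the corrected loss
    student_likelihood = F.log_softmax(predicts, dim=-1)
    targets_prob = F.softmax(targets, dim=-1)
    return (-targets_prob * student_likelihood).sum(dim=-1).mean()

torch.manual_seed(0)
student_logits = torch.randn(8, 5)
teacher_logits = torch.randn(8, 5)

fixed = soft_cross_entropy_fixed(student_logits, teacher_logits)
builtin = F.cross_entropy(student_logits, F.softmax(teacher_logits, dim=-1))
print(torch.allclose(fixed, builtin))  # True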