soft_cross_entropy in TinyBERT needs modification
sbwww opened this issue
B. Shen commented
The soft_cross_entropy loss function in TinyBERT, DynaBERT, and other distilled models seems inaccurate. In the paper, this loss is described as cross-entropy (CE), but here is what the code actually does:
```python
def soft_cross_entropy(predicts, targets):
    student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    return (- targets_prob * student_likelihood).mean()
```
The mean() here simply averages ALL the values in the tensor, but we want to sum the cross-entropy over the label dimension of each sample and then average across samples. So the soft_cross_entropy in the code is actually scaled down by a factor of num_labels!! For example, with a batch of 8 samples and num_labels = 3, mean() divides the total by 8 × 3 = 24 rather than by 8, so the loss is one third of the intended value.
I think it is better to use the code below as an alternative.

```python
return (- targets_prob * student_likelihood).sum(dim=-1).mean()
```
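As a quick sanity check, here is a minimal, self-contained sketch of the issue (the names original_soft_cross_entropy and fixed_soft_cross_entropy, and the tensor shapes, are my own for illustration) showing that the current version is exactly the fixed version divided by num_labels:

```python
import torch

def original_soft_cross_entropy(predicts, targets):
    # Current implementation: mean() averages over batch AND label dimensions.
    student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    return (- targets_prob * student_likelihood).mean()

def fixed_soft_cross_entropy(predicts, targets):
    # Proposed fix: sum over the label dimension first, then average over the batch.
    student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    return (- targets_prob * student_likelihood).sum(dim=-1).mean()

batch_size, num_labels = 8, 3  # arbitrary example shapes
student_logits = torch.randn(batch_size, num_labels)
teacher_logits = torch.randn(batch_size, num_labels)

original = original_soft_cross_entropy(student_logits, teacher_logits)
fixed = fixed_soft_cross_entropy(student_logits, teacher_logits)

# The original loss equals the fixed loss scaled down by num_labels.
assert torch.allclose(original * num_labels, fixed)
```

Since the two versions differ only by a constant factor, the gradient direction is unchanged, but the effective weight of this distillation term relative to any other loss terms shrinks by num_labels, which matters when tuning loss weights.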