JSD loss implementation does not seem to match the formula in your paper
mahdi7 opened this issue · comments
Mahdi commented
I am having trouble seeing how the following implementation matches the softplus version of the JSD described in equation (4) of Appendix F of your paper. I would really appreciate it if you could provide any clarification.
def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs):
    num_neg = neg_mask.int().sum()
    num_pos = pos_mask.int().sum()
    similarity = self.discriminator(anchor, sample)

    E_pos = (np.log(2) - F.softplus(- similarity * pos_mask)).sum()
    E_pos /= num_pos

    neg_sim = similarity * neg_mask
    E_neg = (F.softplus(- neg_sim) + neg_sim - np.log(2)).sum()
    E_neg /= num_neg

    return E_neg - E_pos
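One way to check the snippet against a softplus formula is to compare it numerically with a reference form. The sketch below is a standalone check in plain Python, not the library's code. It assumes the reference is the commonly used softplus form of the JSD estimator, E_pos[-softplus(-T)] - E_neg[softplus(T)], negated into a loss; whether that is exactly equation (4) of the paper is an assumption. The helper names `softplus`, `jsd_loss_masked`, and `jsd_loss_reference` are hypothetical. It relies on two facts: softplus(0) = log 2, so masked-out entries contribute zero to each sum, and softplus(-x) + x = softplus(x).

```python
import math

LOG2 = math.log(2.0)


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def jsd_loss_masked(sim, pos_mask, neg_mask):
    # Mirrors the quoted compute(): masks are applied inside softplus,
    # so masked-out entries become softplus(0) = log 2 and cancel.
    num_pos = sum(sum(row) for row in pos_mask)
    num_neg = sum(sum(row) for row in neg_mask)
    e_pos = 0.0
    e_neg = 0.0
    for s_row, p_row, n_row in zip(sim, pos_mask, neg_mask):
        for s, p, n in zip(s_row, p_row, n_row):
            e_pos += LOG2 - softplus(-s * p)
            neg_sim = s * n
            e_neg += softplus(-neg_sim) + neg_sim - LOG2
    return e_neg / num_neg - e_pos / num_pos


def jsd_loss_reference(sim, pos_mask, neg_mask):
    # Assumed reference form (hypothetical helper):
    # E_pos[softplus(-T)] + E_neg[softplus(T)] - 2 log 2.
    pos = [s for s_row, m_row in zip(sim, pos_mask)
           for s, m in zip(s_row, m_row) if m]
    neg = [s for s_row, m_row in zip(sim, neg_mask)
           for s, m in zip(s_row, m_row) if m]
    pos_term = sum(softplus(-s) for s in pos) / len(pos)
    neg_term = sum(softplus(s) for s in neg) / len(neg)
    return pos_term + neg_term - 2.0 * LOG2


# Toy similarity matrix: diagonal pairs are positives, the rest negatives.
sim = [[0.5, -1.2], [2.0, 0.3]]
pos_mask = [[1, 0], [0, 1]]
neg_mask = [[0, 1], [1, 0]]

masked = jsd_loss_masked(sim, pos_mask, neg_mask)
reference = jsd_loss_reference(sim, pos_mask, neg_mask)
print(masked, reference)
```

Under these assumptions the two values agree up to floating-point error, which suggests the quoted code differs from the assumed softplus form only by the constant 2 log 2 and a sign convention (a loss to minimize rather than a bound to maximize).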
Mahdi commented
Any comment or response on this issue?