Can you explain this math part?
algoteam5 opened this issue
Can you explain the following math part?
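```python
if args.score == 'energy':
    # Energy E(x) = -logsumexp over the class logits; the first len(in_set[0]) rows of x are in-distribution
    Ec_out = -torch.logsumexp(x[len(in_set[0]):], dim=1)
    Ec_in = -torch.logsumexp(x[:len(in_set[0])], dim=1)
    # Squared hinge: penalize in-distribution energy above m_in and outlier energy below m_out
    loss += 0.1*(torch.pow(F.relu(Ec_in-args.m_in), 2).mean() + torch.pow(F.relu(args.m_out-Ec_out), 2).mean())
```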
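The `Ec_out` and `Ec_in` lines compute the energy score E(x) = -logsumexp(f(x)) of every sample from its logits. The first `len(in_set[0])` rows of `x` belong to the in-distribution batch and the remaining rows to the outliers.

The last line is a squared hinge loss rather than a plain mean-square error over all samples. An in-distribution sample is penalized by (E - m_in)^2 only when its energy is above `m_in`, and an outlier is penalized by (m_out - E)^2 only when its energy is below `m_out`; samples already on the correct side of their margin contribute zero. Minimizing this term pushes in-distribution energies down below `m_in` and outlier energies up above `m_out`, so samples whose energy falls in the gap between the two margins get pushed out of it. The result is weighted by 0.1 and added to the existing `loss`.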
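A minimal, dependency-free sketch of the same computation may make the math easier to follow. It assumes plain Python lists of logits instead of the tensor `x`; the function names `energy` and `energy_margin_loss` are hypothetical, and the margin values in the usage below are illustrative numbers, not taken from the repository. Log-sum-exp is computed the stable way (subtract the max before exponentiating), in the same spirit as the numerically stabilized `torch.logsumexp`.

```python
import math
from typing import Sequence


def energy(logits: Sequence[float]) -> float:
    """Energy score E(x) = -log(sum_i exp(f_i(x))), temperature 1."""
    m = max(logits)  # shift by the max so exp() cannot overflow
    return -(m + math.log(sum(math.exp(v - m) for v in logits)))


def energy_margin_loss(
    in_logits: Sequence[Sequence[float]],
    out_logits: Sequence[Sequence[float]],
    m_in: float,
    m_out: float,
    weight: float = 0.1,
) -> float:
    """Hypothetical re-implementation of the squared hinge term in the snippet."""
    e_in = [energy(row) for row in in_logits]
    e_out = [energy(row) for row in out_logits]
    # relu(E_in - m_in)^2: only in-distribution samples with energy above m_in contribute
    in_term = sum(max(0.0, e - m_in) ** 2 for e in e_in) / len(e_in)
    # relu(m_out - E_out)^2: only outliers with energy below m_out contribute
    out_term = sum(max(0.0, m_out - e) ** 2 for e in e_out) / len(e_out)
    return weight * (in_term + out_term)


# A confident in-distribution sample (low energy) and a flat outlier (high energy):
# both already sit outside the gap, so the loss is zero.
print(energy_margin_loss([[30.0, 0.0]], [[0.0, 0.0]], m_in=-25.0, m_out=-7.0))
```

Swapping the two inputs (a flat in-distribution sample and a confident outlier) puts both energies on the wrong side of their margins, and the loss becomes 0.1 * ((25 - log 2)^2 + 23^2), roughly 112.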