ajabri / videowalk

Repository for "Space-Time Correspondence as a Contrastive Random Walk" (NeurIPS 2020)

Home Page: http://ajabri.github.io/videowalk

Cross-entropy loss computation question

vadimkantorov opened this issue

@ajabri The paper specifies that the loss is cross-entropy between the row-normalized cycle transition matrix and the identity matrix:
[screenshot of the loss equation from the paper]

However, the code seems to compute something slightly different:
https://github.com/ajabri/videowalk/blob/0834ff9/code/model.py#L175-L176:

# self.xent = nn.CrossEntropyLoss(reduction="none")
logits = torch.log(A+EPS).flatten(0, -2)   # log of the row-stochastic matrix A, flattened to (rows, cols)
loss = self.xent(logits, target).mean()    # cross-entropy against the identity targets

where matrix A is row-stochastic.

The CrossEntropyLoss module expects unnormalized logits and applies log-softmax internally. So this code effectively computes -log_softmax(log(P[i]))[i], whereas regular cross-entropy against the identity would simply be -log(P[i])[i]. Should nn.NLLLoss have been used instead?

The code seems to use log-probabilities in place of logits (by logits I mean raw unnormalized scores). Is this intentional? If not, it might be a bug. @ajabri Could you please comment on this?
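
To make the concern concrete, here is a toy sketch (made-up values, not from the repo) of what the module does internally: F.cross_entropy first applies a log-softmax to its input and only then takes the negative log-likelihood, so feeding it log(P) evaluates log_softmax(log(P)) rather than log(P) directly.

import torch
import torch.nn.functional as F

# Toy row-stochastic P and "identity" targets (row i should predict column i)
P = torch.tensor([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
target = torch.tensor([0, 1])

scores = torch.log(P)
ce = F.cross_entropy(scores, target)
nll_of_logsoftmax = F.nll_loss(F.log_softmax(scores, dim=-1), target)
print(torch.allclose(ce, nll_of_logsoftmax))  # True: cross_entropy == nll_loss(log_softmax(.))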

Thank you!

The softmax function is invariant to constant translation of logits, and the result is thus the same.

The log-softmax function returns the log of the softmax. So logsoftmax(log(A)) = log(A). In that sense, it is wasteful to use the xent module, and we should just use nll. I guess the xent is there for legacy reasons, as I was experimenting with different losses earlier.
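
For reference, a minimal sketch of the NLLLoss variant being suggested, reusing the transform from the snippet above; the tensor shapes and the EPS value here are illustrative assumptions, not the repo's actual settings.

import torch
import torch.nn as nn

EPS = 1e-20                                        # assumed value, for illustration only
xent = nn.NLLLoss(reduction="none")                # drop-in for CrossEntropyLoss in this case

A = torch.softmax(torch.randn(4, 49, 49), dim=-1)  # row-stochastic transition matrix (toy shapes)
target = torch.arange(49).repeat(4)                # identity targets for every row

log_probs = torch.log(A + EPS).flatten(0, -2)      # same transform as in model.py (called "logits" there)
loss = xent(log_probs, target).mean()              # plain cross-entropy, no extra log-softmax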

logsoftmax(log(A)) = log(A)

logsoftmax(log(A))[i] = log(A_i) - log(sum_j exp(log(A_j))) = log(A_i) - log(sum_j A_j) = log(A_i)
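
A quick numerical check of this identity, with a random row-stochastic A of assumed (toy) size:

import torch
import torch.nn.functional as F

A = torch.softmax(torch.randn(8, 8), dim=-1)   # rows of A sum to 1
lhs = F.log_softmax(torch.log(A), dim=-1)
rhs = torch.log(A)
print(torch.allclose(lhs, rhs, atol=1e-6))     # True: log_softmax(log(A)) == log(A) row-wise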

This identity holds exactly only without the EPS adjustment (we are relying on sum_j A_j = 1), but that should not matter much, right?

Yes, I don't think it is too problematic.
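
For completeness, a small sketch of why the EPS term is harmless (the shapes and the deliberately exaggerated EPS are assumptions, not the repo's values): rows of A+EPS sum to 1 + N*EPS, so the implicit log-softmax inside cross_entropy only subtracts the same constant log(1 + N*EPS) from every row, compared to plain NLL on log(A+EPS).

import torch
import torch.nn.functional as F

N, EPS = 49, 1e-3                                # EPS exaggerated so the offset is visible
A = torch.softmax(torch.randn(N, N), dim=-1)     # row-stochastic, rows sum to 1
target = torch.arange(N)                         # identity targets

logits = torch.log(A + EPS)                      # rows of A+EPS sum to 1 + N*EPS
ce = F.cross_entropy(logits, target, reduction="none")    # applies log-softmax internally
nll = F.nll_loss(logits, target, reduction="none")        # uses log(A+EPS) as-is

diff = ce - nll
print(diff.mean().item())                                 # ~= log(1 + N*EPS) ~= 0.0478
print(torch.log(torch.tensor(1.0 + N * EPS)).item())      # the predicted constant offset
print((diff.max() - diff.min()).item())                   # ~= 0: same offset for every row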