Isn't loss only supposed to be calculated on masked tokens?
EmaadKhwaja opened this issue
emaad commented
In the training loop we have:
imgs = imgs.to(device=args.device)
logits, target = self.model(imgs)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
loss.backward()
However, the output of the transformer is:
_, z_indices = self.encode_to_z(x)
...
a_indices = mask * z_indices + (~mask) * masked_indices
a_indices = torch.cat((sos_tokens, a_indices), dim=1)
target = torch.cat((sos_tokens, z_indices), dim=1)
logits = self.transformer(a_indices)
return logits, target
which means the returned target contains the original, unmasked image tokens at every position.
The MaskGIT paper seems to indicate that the loss was computed only on the masked tokens.
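One way to restrict the loss to the masked positions is to overwrite the kept positions' targets with `F.cross_entropy`'s `ignore_index` (default `-100`), so those positions contribute nothing to the loss. A minimal sketch with random tensors (not the repo's actual code; it ignores the prepended SOS token, and `mask` follows the convention in the snippet above, i.e. True = original token kept):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical shapes: batch of 2, sequence length 5, vocab size 10.
logits = torch.randn(2, 5, 10)          # transformer output
target = torch.randint(0, 10, (2, 5))   # original z_indices
mask = torch.tensor([[True, False, True, True, False],
                     [False, False, True, True, True]])  # True = token kept

# Replace targets at KEPT positions with ignore_index, so the loss
# is averaged only over the masked positions.
masked_target = target.masked_fill(mask, -100)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                       masked_target.reshape(-1))
```

This gives the same value as computing cross-entropy directly on `logits[~mask]` and `target[~mask]`, but keeps the flattened-reshape pattern used in the training loop.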
Darius commented
I've tried both strategies for a simple MaskGIT on CIFAR-10, but generation quality is still poor either way. There seem to be training tricks the authors don't spell out in the paper.