dome272 / MaskGIT-pytorch

PyTorch implementation of MaskGIT: Masked Generative Image Transformer (https://arxiv.org/pdf/2202.04200.pdf)


Isn't loss only supposed to be calculated on masked tokens?

EmaadKhwaja opened this issue · comments


In the training loop we have:

imgs = imgs.to(device=args.device)
logits, target = self.model(imgs)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
loss.backward()
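If the loss is meant to cover only the masked positions, one common fix is to set the targets at unmasked positions to `cross_entropy`'s `ignore_index`. A minimal standalone sketch (shapes and variable names are illustrative, not the repo's actual ones; `mask` follows the repo's convention where `True` means the token was kept):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, seq_len, vocab = 2, 8, 16
logits = torch.randn(batch, seq_len, vocab)          # transformer output
target = torch.randint(0, vocab, (batch, seq_len))   # original token indices

# True = token was kept; False = token was replaced by the mask token.
mask = torch.zeros(batch, seq_len, dtype=torch.bool)
mask[:, : seq_len // 2] = True

# Only the masked-out positions (~mask) should contribute to the loss,
# so mark the kept positions with ignore_index.
masked_target = target.clone()
masked_target[mask] = -100  # -100 is cross_entropy's default ignore_index

loss = F.cross_entropy(
    logits.reshape(-1, vocab),
    masked_target.reshape(-1),
    ignore_index=-100,
)
```

With `reduction='mean'` (the default), the loss is averaged over the non-ignored positions only, so the number of masked tokens per batch does not skew the scale.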

However, the output of the transformer is:

  _, z_indices = self.encode_to_z(x)
  ...
  a_indices = mask * z_indices + (~mask) * masked_indices

  a_indices = torch.cat((sos_tokens, a_indices), dim=1)

  target = torch.cat((sos_tokens, z_indices), dim=1)

  logits = self.transformer(a_indices)

  return logits, target

which means the returned target is the original unmasked image tokens.
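For context, the `mask` consumed by `mask * z_indices + (~mask) * masked_indices` comes from the paper's cosine masking schedule. A rough sketch of how such a mask could be sampled (my own reading of the paper, not the repo's exact code; `make_mask` is a hypothetical helper):

```python
import math
import torch

torch.manual_seed(0)

def make_mask(batch: int, seq_len: int) -> torch.Tensor:
    """Sample a per-example token mask following MaskGIT's cosine schedule.

    Returns a bool tensor where True marks tokens that are KEPT, matching
    the convention in `mask * z_indices + (~mask) * masked_indices`.
    """
    # Draw r ~ U(0, 1) and mask a fraction gamma(r) = cos(pi * r / 2)
    # of the tokens, so the masking ratio varies across the batch.
    r = torch.rand(batch)
    mask_ratio = torch.cos(r * math.pi / 2)
    num_masked = (mask_ratio * seq_len).long().clamp(min=1)

    mask = torch.ones(batch, seq_len, dtype=torch.bool)
    for i in range(batch):
        idx = torch.randperm(seq_len)[: num_masked[i]]
        mask[i, idx] = False  # False = position replaced by the mask token
    return mask
```

Because the ratio is resampled every step, the model sees everything from lightly to almost fully masked sequences during training.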

The MaskGIT paper seems to suggest that the loss is only calculated on the masked tokens:

[screenshot of the masked cross-entropy loss equation from the MaskGIT paper]
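For reference, my reading of the paper's objective is a cross-entropy summed only over positions where the mask bit $m_i = 1$, conditioned on the unmasked tokens $Y_{\bar{M}}$, roughly:

$$
\mathcal{L}_{\text{mask}} = -\,\mathbb{E}_{Y \in \mathcal{D}} \Bigg[ \sum_{\forall i \in [1,N],\; m_i = 1} \log p\big(y_i \mid Y_{\bar{M}}\big) \Bigg]
$$

i.e. unmasked positions contribute nothing to the gradient, unlike the `cross_entropy` over the full `target` in the training loop above.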

I've attempted both strategies with a simple MaskGIT on CIFAR-10, but the generation quality still seems bad either way. Perhaps there are tricks in the training scheme that the authors don't spell out in the paper.