Questions about function forward() in NUWA please.
Fitzwong opened this issue
I'm confused that, in the forward() function of class NUWA, the ground-truth video is fed to the transformer to compute the output video, which is different from what generate() does.
frame_embeddings = self.video_transformer(
frame_embeddings, # calculated from ground-truth video
context = text_embeds,
context_mask = text_mask
)
So when training NUWA, the loss comes from the logits. But the logits come not only from the text but also from the ground-truth video (a single transformer pass, unlike the autoregressive loop in generate()). Isn't that a kind of cheating during training? Or should I generate the logits the same way as in generate(), and then compute the loss to train?
The reason is that we compress the video into a sequence of tokens, and then have each token predict the next token, autoregressively. This is standard teacher forcing: the causal attention mask prevents any position from attending to future tokens, so even though the ground-truth video is in the input, the prediction at position i can only use tokens up to i. Training with the full sequence in one pass is therefore equivalent to (but much faster than) running the autoregressive loop in generate().
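A minimal sketch (pure Python, illustrative names not taken from the repo) of the two ingredients that make this training scheme fair: the causal mask, which lets position i see only positions ≤ i, and the input/target shift, which makes every position predict the next ground-truth token in a single forward pass:

```python
def causal_mask(n):
    # mask[i][j] is True when position i is allowed to attend to position j;
    # only j <= i is permitted, so no position sees the "future"
    return [[j <= i for j in range(n)] for i in range(n)]

def training_pairs(tokens):
    # teacher forcing: inputs are tokens[:-1], targets are tokens[1:],
    # so the loss at position i asks the model to predict token i+1
    return list(zip(tokens[:-1], tokens[1:]))

video_tokens = [5, 9, 2, 7]          # hypothetical VQ-VAE token ids
print(training_pairs(video_tokens))  # [(5, 9), (9, 2), (2, 7)]
print(causal_mask(3))                # lower-triangular True pattern
```

Because the mask hides future tokens, computing the loss over all positions at once gives exactly the same per-position predictions that generate() would produce token by token; the only difference is speed.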