lucidrains / nuwa-pytorch

Implementation of NÜWA, state-of-the-art attention network for text-to-video synthesis, in Pytorch

Questions about function forward() in NUWA please.

Fitzwong opened this issue

I'm confused that, in the forward() function of class NUWA, the ground-truth video is fed to the transformer to calculate the output video, which is different from the generate() function.

frame_embeddings = self.video_transformer(
    frame_embeddings,  # calculated from the ground-truth video
    context = text_embeds,
    context_mask = text_mask
)

So when training NUWA, the loss is computed from the logits. But the logits come not only from the text, but also from the ground-truth video (in a single transformer pass, unlike the token-by-token autoregressive loop in generate()). Isn't that some kind of cheating during training? Or should I produce the logits the same way as in generate(), and then calculate the loss from those?

so the reason is that we compress the video into a sequence of discrete tokens, and then train each position to predict the next token, autoregressively. the attention in the video transformer is causally masked, so position i can only attend to tokens 0..i and never sees the token it is asked to predict. feeding in the ground-truth video during training is standard teacher forcing, not cheating; generate() differs only in that no ground truth exists at inference, so tokens must be sampled one at a time
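
To make the contrast concrete, here is a minimal sketch, assuming a hypothetical causally-masked model over discrete video token ids (the names training_step, sample, and model are illustrative, not the repo's actual API):

import torch
import torch.nn.functional as F

def training_step(model, tokens):
    # tokens: (batch, seq) ground-truth video token ids
    # the model is assumed to apply a causal mask internally, so position i
    # attends only to positions <= i and never sees the token it must predict
    logits = model(tokens[:, :-1])               # one parallel pass (teacher forcing)
    targets = tokens[:, 1:]                      # next-token targets, shifted by one
    return F.cross_entropy(logits.transpose(1, 2), targets)

@torch.no_grad()
def sample(model, prefix, num_tokens):
    # at inference there is no ground truth, so tokens are generated one at a
    # time, each conditioned on everything sampled so far
    tokens = prefix
    for _ in range(num_tokens):
        logits = model(tokens)[:, -1]            # distribution over the next token
        next_token = logits.argmax(dim=-1, keepdim=True)
        tokens = torch.cat((tokens, next_token), dim=-1)
    return tokens

Because of the causal mask, the teacher-forced pass in training_step computes, in parallel, the same per-position predictions the sampling loop would make if it were fed the ground-truth prefix at every step, which is why training this way is standard practice for autoregressive models.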