MeteoSwiss / ldcast

Latent diffusion for generative precipitation nowcasting

Some questions about VAE

clearlyzero opened this issue

Thank you for the great work! I have a question about the VAE.
In autoenc.py, lines 48-52:

    def _loss(self, batch):
        (x,y) = batch
        while isinstance(x, list) or isinstance(x, tuple):
            x = x[0][0]
        (y_pred, mean, log_var) = self.forward(x)

        rec_loss = (y-y_pred).abs().mean()
        kl_loss = kl_from_standard_normal(mean, log_var)

        total_loss = rec_loss + self.kl_weight * kl_loss

        return (total_loss, rec_loss, kl_loss)
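This loss combines a mean absolute (L1) reconstruction error with a KL term scaled by kl_weight. The pure-Python sketch below shows that combination. It assumes that kl_from_standard_normal computes the closed-form KL divergence between N(mean, exp(log_var)) and a standard normal, averaged over elements. The snippet above does not show whether the repository averages or sums, and vae_loss is a hypothetical helper name.

```python
import math


def kl_from_standard_normal(mean, log_var):
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per element:
    #   0.5 * (mu^2 + sigma^2 - 1 - log(sigma^2)), with log_var = log(sigma^2).
    # Assumption: averaged over elements (the repo may reduce differently).
    terms = [0.5 * (m * m + math.exp(lv) - 1.0 - lv) for m, lv in zip(mean, log_var)]
    return sum(terms) / len(terms)


def l1_loss(y, y_pred):
    # Mean absolute error, analogous to (y - y_pred).abs().mean()
    return sum(abs(a - b) for a, b in zip(y, y_pred)) / len(y)


def vae_loss(y, y_pred, mean, log_var, kl_weight):
    # Hypothetical helper mirroring the structure of _loss
    rec_loss = l1_loss(y, y_pred)
    kl_loss = kl_from_standard_normal(mean, log_var)
    total_loss = rec_loss + kl_weight * kl_loss
    return (total_loss, rec_loss, kl_loss)
```

With a perfect reconstruction and a latent that already matches N(0, 1), all three terms are zero. A small kl_weight keeps the KL term from dominating the reconstruction error.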

    (y_pred, mean, log_var) = self.forward(x)

I'm a little confused about this line.

Should it be this (the current code):

    (x, y) = batch
    while isinstance(x, list) or isinstance(x, tuple):
        x = x[0][0]
    (y_pred, mean, log_var) = self.forward(x)

or this:

    (x, y) = batch
    while isinstance(x, list) or isinstance(x, tuple):
        x = x[0][0]
    (y_pred, mean, log_var) = self.forward(y)

Should it be self.forward(y) or self.forward(x)? Does x here hold the 4 conditioning frames? If so, then y holds the frames to be predicted. Which one should be passed in here? Should it be self.forward(y)?
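The question matters because rec_loss compares y_pred against y. For a pure reconstruction objective, the tensor passed to forward and the target should refer to the same frames, unless the data loader already supplies identical frames as x and y. The sketch below illustrates this with hypothetical stand-ins. Frames replaces a tensor, perfect_autoencoder is a model that reconstructs its input exactly, and the nested batch layout ([(frames, timestamps)], target) is an assumption chosen so that the x[0][0] unwrapping loop applies. None of these reflect the repository's actual data pipeline.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Frames:
    # Hypothetical stand-in for a frame tensor; like a torch.Tensor,
    # it is neither a list nor a tuple, so the unwrapping loop stops at it.
    values: tuple


def unwrap(x):
    # Mirrors the loop in _loss: keep taking x[0][0] while x is a list or tuple
    while isinstance(x, (list, tuple)):
        x = x[0][0]
    return x


def perfect_autoencoder(frames):
    # Hypothetical model that reconstructs its input exactly
    return frames


def l1(a, b):
    # Mean absolute difference between two Frames objects
    return sum(abs(p - q) for p, q in zip(a.values, b.values)) / len(a.values)


def rec_loss(batch, use_target_as_input):
    (x, y) = batch
    x = unwrap(x)
    model_input = y if use_target_as_input else x
    y_pred = perfect_autoencoder(model_input)
    # Reconstruction is always scored against y, as in _loss
    return l1(y, y_pred)


cond = Frames((1.0, 2.0))    # hypothetical conditioning frames
future = Frames((3.0, 5.0))  # hypothetical frames to predict
batch = ([(cond, "timestamps")], future)
```

Even a perfect reconstructor gets a nonzero loss from forward(x) when x and y differ (rec_loss(batch, False) is 2.5 here). Feeding the same frames as input and target gives zero (rec_loss(batch, True) is 0.0). If the autoencoder's data loader provides the same frames as x and y, the current forward(x) call is equivalent to forward(y).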