How can we get "p0_z" for the KL loss?
XiaoyanQian opened this issue
Xiaoyan Qian commented
Hi, could you help me figure out what p0_z is for the KL loss in the following code?
```python
kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
loss = kl.mean()
```
How can I obtain p0_z? Any thoughts?
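In a typical VAE-style setup, `p0_z` is the prior distribution over the latent code `z`, and a common choice is a standard normal N(0, I). The sketch below assumes that choice; it is not confirmed to be what this particular codebase uses. The latent size `z_dim`, the batch size, and the encoder outputs `mean_z` / `std_z` are hypothetical placeholders standing in for whatever the real encoder produces.

```python
import torch
import torch.distributions as dist

torch.manual_seed(0)

# Assumption: the prior p0_z is a standard normal N(0, I) over a
# latent code with z_dim dimensions (z_dim is a hypothetical value).
z_dim = 8
p0_z = dist.Normal(torch.zeros(z_dim), torch.ones(z_dim))

# Hypothetical encoder output for a batch of inputs: a mean and a
# strictly positive standard deviation for each latent dimension.
batch_size = 4
mean_z = torch.randn(batch_size, z_dim)
std_z = torch.rand(batch_size, z_dim) + 0.1  # keep std > 0
q_z = dist.Normal(mean_z, std_z)

# kl_divergence works element-wise: p0_z has shape (z_dim,) and
# broadcasts against q_z's shape (batch_size, z_dim). Summing over the
# last dimension gives one KL value per sample; the mean over the
# batch gives a scalar loss, as in the snippet above.
kl = dist.kl_divergence(q_z, p0_z).sum(dim=-1)
loss = kl.mean()
print(kl.shape, loss.item())
```

With a standard normal prior, each per-dimension term equals the closed form 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2), so the prior needs no learnable parameters and can simply be created once (for example on the model's device) and reused.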