lxuechen / BDMC

PyTorch implementation of Bidirectional Monte Carlo, Annealed Importance Sampling, and Hamiltonian Monte Carlo.

Error

XintianHan opened this issue.

Traceback (most recent call last):
  File "bdmc.py", line 67, in <module>
    main()
  File "bdmc.py", line 63, in main
    bdmc(model, loader, forward_schedule=np.linspace(0., 1., 500), n_sample=100)
  File "bdmc.py", line 33, in bdmc
    forward_logws = ais_trajectory(model, load, mode='forward', schedule=forward_schedule, n_sample=n_sample)
  File "../BDMC/ais.py", line 76, in ais_trajectory
    log_int_1 = log_f_i(current_z, batch, t0)
  File "../BDMC/ais.py", line 42, in log_f_i
    log_likelihood = discretized_logistic(*model.decode(z), data)
TypeError: discretized_logistic() takes 3 positional arguments but 10001 were given

I got this error simply running your code. How can I fix it?
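The odd count of 10001 follows from how * unpacking works in a Python call: applied to a single array or tensor, it produces one positional argument per entry along the first dimension, instead of the two values (a mean and a log-scale) that the function expects. The minimal sketch below reproduces the message with NumPy. It assumes, without confirmation from the thread, that model.decode(z) returned one array with 10000 rows (for example 100 samples times a batch of 100); discretized_logistic here is a hypothetical stub that only mimics the three-parameter arity.

```python
import numpy as np

def discretized_logistic(mean, log_scale, data):
    """Hypothetical stub with the same 3-argument arity as in the traceback."""
    return None

# Assumption: decode(z) returned ONE array with 10000 rows instead of a
# (mean, log_scale) pair. Trailing dimensions do not affect the count.
decoded = np.zeros((10000, 4))
data = np.zeros((10000, 4))

try:
    # *decoded expands to 10000 row arguments, plus data = 10001 arguments.
    discretized_logistic(*decoded, data)
except TypeError as err:
    print(err)

# With a 2-tuple, * expands to exactly two arguments, plus data = 3.
decoded_pair = (np.zeros((10000, 4)), np.zeros((10000, 4)))
discretized_logistic(*decoded_pair, data)
```

In other words, the starred call only works when the decoder returns a (mean, log_scale) pair; a decoder that returns a single tensor (for example, Bernoulli logits) needs a likelihood function that takes that tensor directly.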

Hi, thanks for reaching out. The line

log_likelihood = discretized_logistic(*model.decode(z), data)

is a version I used to test the behavior on CIFAR, which turned out to be numerically very unstable and slow. The stable version is the line just above it (previously commented out), and it runs on a model trained on MNIST/Fashion-MNIST. I have updated the repo to reflect this.
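One common source of instability in a discretized logistic likelihood (not necessarily the one in this repo) is computing sigmoid(b) - sigmoid(a) directly: when a pixel lies far from the predicted mean, both sigmoids round to 1.0 and the log-probability becomes -inf. The sketch below is hypothetical and is not the repo's discretized_logistic. It assumes pixels scaled to k/255 for k = 0..255, a (mean, log_scale) parameterization, and 256 bins whose first and last bins absorb the tails. It uses the exact identity log(sigmoid(b) - sigmoid(a)) = log sigmoid(b) + log sigmoid(-a) + log(1 - exp(a - b)), whose terms are all computed stably.

```python
import numpy as np

HALF_BIN = 0.5 / 255.0  # assumed: pixel values k/255, bins of width 1/255

def softplus(v):
    # log(1 + exp(v)) without overflow
    return np.logaddexp(0.0, v)

def discretized_logistic_logpdf(mean, log_scale, x):
    """Hypothetical per-pixel log P(x) under a logistic discretized into 256 bins."""
    inv_s = np.exp(-log_scale)
    a = (x - HALF_BIN - mean) * inv_s  # standardized lower bin edge
    b = (x + HALF_BIN - mean) * inv_s  # standardized upper bin edge
    # The first and last bins absorb the tails of the distribution.
    a = np.where(x <= 0.0, -np.inf, a)
    b = np.where(x >= 1.0, np.inf, b)
    # log(sigmoid(b) - sigmoid(a)), rewritten so nothing cancels catastrophically
    return -softplus(-b) - softplus(a) + np.log(-np.expm1(a - b))

def naive_logpdf(mean, log_scale, x):
    """Direct difference of sigmoids, shown only for comparison."""
    s = np.exp(log_scale)
    sig = lambda v: 1.0 / (1.0 + np.exp(-v))
    return np.log(sig((x + HALF_BIN - mean) / s) - sig((x - HALF_BIN - mean) / s))

pixels = np.arange(256) / 255.0
probs = np.exp(discretized_logistic_logpdf(0.3, np.log(0.05), pixels))
print(probs.sum())  # the bins tile the real line, so this is 1 up to rounding

with np.errstate(divide="ignore"):
    print(naive_logpdf(0.0, np.log(0.01), 0.9))  # -inf: both sigmoids round to 1.0
print(discretized_logistic_logpdf(0.0, np.log(0.01), 0.9))  # finite, roughly -90.9
```

Note that this fixes only the floating-point cancellation; it does not address the modeling difficulty of CIFAR itself.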

Ping me if you have any other questions. Thanks.

I see. Thanks! Do you have any intuition about why your code does not work well on CIFAR-10?

CIFAR is quite a difficult dataset to model with VAEs, and with generative models as a whole. It's not surprising, for example, that achieving good marginal log-likelihood (MLL) values requires hierarchical latent variable models and sophisticated inference methods such as the ResNet VAE (https://arxiv.org/abs/1606.04934), which is trained on multiple GPUs. This is mostly because the observation space is much higher-dimensional than that of MNIST/Fashion-MNIST (32×32×3 = 3072 values per CIFAR image versus 28×28 = 784 per MNIST image).

For such models it is also quite difficult to sandwich the MLL tightly between the stochastic lower bound from (forward) AIS and the stochastic upper bound from reverse AIS within a reasonable amount of time and compute.
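The AIS / reverse-AIS sandwich can be illustrated on a toy model where log p(x) is known in closed form. This is a minimal NumPy sketch, not the repo's ais_trajectory: the 1-D Gaussian model, the random-walk Metropolis transitions (the repo itself advertises Hamiltonian Monte Carlo), and all function names are assumptions for illustration. Forward AIS starts from the prior and anneals toward the posterior along p(z) p(x|z)^t, giving a stochastic lower bound on log p(x); reverse AIS starts from an exact posterior sample, which is only available for toy or simulated data, and gives a stochastic upper bound.

```python
import numpy as np

# Toy model (assumption, for illustration only):
#   z ~ N(0, 1),  x | z ~ N(z, SIGMA2)  =>  p(x) = N(x; 0, 1 + SIGMA2)
SIGMA2 = 0.5

def log_normal(v, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

def log_prior(z):
    return log_normal(z, 0.0, 1.0)

def log_lik(z, x):
    return log_normal(x, z, SIGMA2)

def mh_step(z, x, t, rng, step=0.5, n_steps=5):
    """Random-walk Metropolis steps leaving p(z) p(x|z)^t invariant."""
    for _ in range(n_steps):
        prop = z + step * rng.standard_normal(z.shape)
        log_acc = (log_prior(prop) + t * log_lik(prop, x)
                   - log_prior(z) - t * log_lik(z, x))
        accept = np.log(rng.uniform(size=z.shape)) < log_acc
        z = np.where(accept, prop, z)
    return z

def log_mean_exp(a):
    m = a.max()
    return m + np.log(np.mean(np.exp(a - m)))

def ais(x, schedule, rng, z0):
    """Accumulate log importance weights along an annealing schedule.

    Forward: schedule runs 0 -> 1 and z0 is drawn from the prior.
    Reverse: schedule runs 1 -> 0 and z0 is drawn from the exact posterior.
    """
    z = z0
    logw = np.zeros_like(z0)
    for t_prev, t_next in zip(schedule[:-1], schedule[1:]):
        logw += (t_next - t_prev) * log_lik(z, x)  # log f_next(z) - log f_prev(z)
        z = mh_step(z, x, t_next, rng)
    return logw

rng = np.random.default_rng(0)
x, n_chains = 1.3, 500
schedule = np.linspace(0.0, 1.0, 200)
exact = log_normal(x, 0.0, 1.0 + SIGMA2)
post_mean, post_var = x / (1.0 + SIGMA2), SIGMA2 / (1.0 + SIGMA2)

fwd = ais(x, schedule, rng, rng.standard_normal(n_chains))
rev = ais(x, schedule[::-1], rng,
          post_mean + np.sqrt(post_var) * rng.standard_normal(n_chains))
lower = log_mean_exp(fwd)   # estimates log p(x) from below
upper = -log_mean_exp(rev)  # reverse weights estimate 1 / p(x)
print(f"lower {lower:.4f}  exact {exact:.4f}  upper {upper:.4f}")
```

On this 1-D problem both estimates land close to the exact value with only 200 intermediate distributions; on a high-dimensional model such as a CIFAR VAE, the gap between them typically closes only with many more intermediate distributions and transition steps, which is the cost the reply refers to.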

As a side note, the implementation in the repo should work on PyTorch v0.2.0, but it is not guaranteed to work with later versions because of major API changes in PyTorch.

That helps a lot! I really appreciate it.