1Konny / Beta-VAE

PyTorch implementation of β-VAE

β-VAE

PyTorch reproduction of the two papers below:

  1. β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework, Higgins et al., ICLR, 2017
  2. Understanding disentangling in β-VAE, Burgess et al., arxiv:1804.03599, 2018

Dependencies

python 3.6.4
pytorch 0.3.1.post2
visdom

Datasets

same as here

Usage

initialize visdom

python -m visdom.server

you can reproduce the results below by running

sh run_celeba_H_beta10_z10.sh
sh run_celeba_H_beta10_z32.sh
sh run_3dchairs_H_beta4_z10.sh
sh run_3dchairs_H_beta4_z16.sh
sh run_dsprites_B_gamma100_z10.sh

or you can run your own experiments by setting the parameters manually.
for the --objective and --model arguments, there are two options, H and B, which refer to the methods proposed in Higgins et al. and Burgess et al., respectively.
the arguments --C_max and --C_stop_iter should be set when using --objective B. for further details, please refer to Burgess et al.

e.g.
python main.py --dataset 3DChairs --beta 4 --lr 1e-4 --z_dim 10 --objective H --model H --max_iter 1e6 ...
python main.py --dataset dsprites --gamma 1000 --C_max 25 --C_stop_iter 1e5 --lr 5e-4 --z_dim 10 --objective B --model B --max_iter 1e6 ...
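
as a rough illustration of how the two objectives differ, here is a minimal sketch in plain Python. it uses scalar stand-ins for the per-batch reconstruction loss and KL divergence, and the function names are hypothetical (they are not taken from this repository). it assumes the formulation in Burgess et al., where the capacity C grows linearly from 0 to C_max over C_stop_iter iterations and is then held fixed.

```python
# Illustrative sketch only: scalar stand-ins for the per-batch loss terms.
# Function names are hypothetical and not taken from the repository.


def capacity(global_iter, c_max, c_stop_iter):
    """Linearly increase the KL capacity C from 0 to c_max, then hold it."""
    c = c_max * global_iter / c_stop_iter
    return min(max(c, 0.0), c_max)


def objective_h(recon_loss, total_kld, beta):
    """Objective H (Higgins et al.): reconstruction loss + beta * KL."""
    return recon_loss + beta * total_kld


def objective_b(recon_loss, total_kld, gamma, c):
    """Objective B (Burgess et al.): penalize the gap between KL and C."""
    return recon_loss + gamma * abs(total_kld - c)


if __name__ == "__main__":
    recon_loss, total_kld = 120.0, 18.0  # made-up values for illustration
    print("H:", objective_h(recon_loss, total_kld, beta=4))
    for it in (0, 50_000, 100_000, 200_000):
        c = capacity(it, c_max=25, c_stop_iter=1e5)
        print("B:", it, c, objective_b(recon_loss, total_kld, gamma=1000, c=c))
```

with objective H, a larger --beta always pushes the KL term toward zero. with objective B, the penalty is smallest when the KL term matches the current capacity, which is why --C_max and --C_stop_iter are needed there.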

check the training process on the visdom server

localhost:8097

Results

3D Chairs

sh run_3dchairs_H_beta4_z10.sh

3dchairs_beta4_z10

sh run_3dchairs_H_beta4_z16.sh

3dchairs_beta4_z16

CelebA

sh run_celeba_H_beta10_z10.sh

celeba

sh run_celeba_H_beta10_z32.sh

celeba

dSprites

sh run_dsprites_B_gamma100_z10.sh
visdom line plot

dsprites_plot

latent traversal gif (--save_output True)
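
for readers unfamiliar with the term, a latent traversal varies one latent dimension while holding the others fixed, and decodes each resulting vector into an image. the sketch below only builds the latent vectors, in plain Python. the helper names and the [-3, 3] range are assumptions for illustration, not this repository's code or settings.

```python
# Illustrative sketch only: builds the latent vectors for one traversal row.
# Helper names and the traversal range are hypothetical.


def linspace(start, stop, num):
    """Return `num` evenly spaced values from start to stop (inclusive)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def latent_traversal(z, dim, values):
    """Return copies of z in which only coordinate `dim` takes each value."""
    rows = []
    for v in values:
        z_new = list(z)  # copy so the original latent vector is untouched
        z_new[dim] = v
        rows.append(z_new)
    return rows


if __name__ == "__main__":
    z = [0.0] * 10  # e.g. a 10-dimensional latent code (--z_dim 10)
    for z_new in latent_traversal(z, dim=2, values=linspace(-3, 3, 7)):
        print(z_new)  # each vector would then be passed to the decoder
```

if a dimension has learned a disentangled factor, the decoded frames along its traversal should change in only that factor (for example, position or scale in dSprites).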

reconstruction (left: true, right: reconstruction)

Reference

  1. β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework, Higgins et al., ICLR, 2017
  2. Understanding disentangling in β-VAE, Burgess et al., arxiv:1804.03599, 2018
  3. GitHub repo: TensorFlow implementation by miyosuda

License: MIT License

