zzp1012 / realNVP

PyTorch implementation of realNVP

realNVP

A PyTorch implementation of the training procedure from Density Estimation Using Real NVP (Dinh, Sohl-Dickstein and Bengio, ICLR 2017). The original TensorFlow implementation can be found at https://github.com/tensorflow/models/tree/master/research/real_nvp.
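The core building block of Real NVP is the affine coupling layer, which can be inverted in closed form and has a triangular Jacobian. The sketch below illustrates the idea in NumPy rather than PyTorch and is not this repository's code. It splits each 4-dimensional input into two halves (Real NVP itself uses checkerboard and channel-wise masks), and `s`, `t`, `W_S` and `W_T` are hypothetical stand-ins for the paper's deep residual networks.

```python
import numpy as np

# Hypothetical stand-ins for the scale and translation networks s(.) and t(.).
# In Real NVP these are deep ResNets; fixed random linear maps keep this sketch
# self-contained. Inputs are 4-dimensional and split into two halves of size 2.
_rs = np.random.RandomState(0)
W_S = _rs.normal(scale=0.5, size=(2, 2))
W_T = _rs.normal(scale=0.5, size=(2, 2))


def s(x1):
    # Log-scale, squashed with tanh (an assumption of this sketch) to stay bounded.
    return np.tanh(x1 @ W_S)


def t(x1):
    return x1 @ W_T


def coupling_forward(x):
    """Affine coupling: y1 = x1, y2 = x2 * exp(s(x1)) + t(x1).

    Returns y and log|det(dy/dx)| per sample, which is simply sum(s(x1)).
    """
    x1, x2 = x[:, :2], x[:, 2:]
    log_scale = s(x1)
    y2 = x2 * np.exp(log_scale) + t(x1)
    return np.concatenate([x1, y2], axis=1), log_scale.sum(axis=1)


def coupling_inverse(y):
    """Exact inverse: x2 = (y2 - t(y1)) * exp(-s(y1)); s and t are never inverted."""
    y1, y2 = y[:, :2], y[:, 2:]
    x2 = (y2 - t(y1)) * np.exp(-s(y1))
    return np.concatenate([y1, x2], axis=1)


x = np.random.RandomState(1).normal(size=(5, 4))
y, log_det = coupling_forward(x)
print(np.allclose(coupling_inverse(y), x))  # True
```

Because y1 passes through unchanged, the Jacobian is block lower-triangular and its log-determinant is just the sum of the log-scales, which is what makes exact maximum-likelihood training cheap.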

Implementation Details

This implementation supports training on four datasets: CIFAR-10, CelebA, ImageNet 32x32 and ImageNet 64x64. For each dataset, only the training split is used to learn the distribution, and labels are not used. Raw data is dequantized, randomly flipped horizontally and logit-transformed (see the paper for details). The network architecture faithfully reproduces the one in the paper, and the hyperparameters suggested in the paper are used as defaults. Optimization uses Adam with its default parameters. Model performance, measured in bits/dim, matches the results reported in the paper.
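The preprocessing and the bits/dim metric can be sketched as follows. This is a NumPy illustration under stated assumptions, not the repository's code: `ALPHA = 0.05` is the value given in the Real NVP paper and is assumed to carry over, batches are assumed to be laid out as (N, C, H, W), and `preprocess` and `bits_per_dim` are hypothetical helper names. `log_pz` stands for the log-density that the flow assigns to the logit-space batch.

```python
import numpy as np

ALPHA = 0.05  # value from the Real NVP paper; assumed here, not read from this repo


def preprocess(images, rng):
    """Dequantize, randomly flip and logit-transform uint8 images (N, C, H, W).

    Returns z = logit(ALPHA + (1 - ALPHA) * x), where x in [0, 1) is the
    dequantized image, and log|det(dz/dx)| summed over all dimensions per image.
    """
    # Dequantization: add uniform noise in [0, 1) to each integer pixel, rescale.
    x = (images.astype(np.float64) + rng.uniform(size=images.shape)) / 256.0
    # Random horizontal flip: reverse the width axis for about half the batch.
    flip = rng.uniform(size=len(x)) < 0.5
    x[flip] = x[flip][..., ::-1]
    # Logit transform; ALPHA keeps the argument away from 0 and 1.
    p = ALPHA + (1.0 - ALPHA) * x
    z = np.log(p) - np.log1p(-p)
    # Elementwise derivative: dz/dx = (1 - ALPHA) / (p * (1 - p)).
    log_det = np.log(1.0 - ALPHA) - np.log(p) - np.log1p(-p)
    return z, log_det.reshape(len(x), -1).sum(axis=1)


def bits_per_dim(log_pz, log_det, num_dims):
    """Convert log p(z) in logit space to bits/dim on the 0..255 pixel scale."""
    log_px = log_pz + log_det                 # density of x on [0, 1)^D
    nll = -log_px + num_dims * np.log(256.0)  # rescale the density to [0, 256)^D
    return nll / (num_dims * np.log(2.0))     # nats -> bits, per dimension


rng = np.random.RandomState(0)
batch = rng.randint(0, 256, size=(4, 3, 32, 32)).astype(np.uint8)
z, log_det = preprocess(batch, rng)
# Sanity check: a density uniform on [0, 1)^D (log p(x) = 0) scores 8 bits/dim.
print(bits_per_dim(-log_det, log_det, 3 * 32 * 32))
```

The log-determinant of the logit transform has to be added back before converting to bits/dim; otherwise the number would describe the density in logit space rather than on the pixel data.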

Samples

The samples are generated from models trained with default parameters. Each iteration corresponds to a minibatch of 64 images.
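To relate the iteration counts below to passes over the data: at 64 images per iteration, the 50,000-image CIFAR-10 training split is covered every 781.25 iterations, so 80000 iterations is roughly 102 epochs. A tiny helper (hypothetical, for illustration only) makes the conversion explicit:

```python
def epochs_seen(iterations, batch_size=64, train_size=50_000):
    """Passes over the training split after `iterations` minibatches.

    train_size defaults to the 50,000-image CIFAR-10 training split.
    """
    return iterations * batch_size / train_size


print(epochs_seen(1000))   # 1.28
print(epochs_seen(80000))  # 102.4
```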

CIFAR-10

1000 iterations

80000 iterations

CelebA

1000 iterations

60000 iterations

ImageNet 32x32

1000 iterations

80000 iterations

ImageNet 64x64

1000 iterations

60000 iterations

Training

The code runs on a single GPU and has been tested with:

  • Python 3.7.2
  • torch 1.0.0
  • numpy 1.15.4

Commands for training on each dataset:

python train.py --dataset=cifar10 --batch_size=64 --base_dim=64 --res_blocks=8 --max_iter=80000
python train.py --dataset=celeba --batch_size=64 --base_dim=32 --res_blocks=2 --max_iter=60000
python train.py --dataset=imnet32 --batch_size=64 --base_dim=32 --res_blocks=4 --max_iter=80000
python train.py --dataset=imnet64 --batch_size=64 --base_dim=32 --res_blocks=2 --max_iter=60000
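The four training commands differ only in a few flags. If you prefer to launch runs from Python, one option is to keep the per-dataset settings in a single table and build each command from it. `CONFIGS`, `build_command` and `launch` are hypothetical helpers, not part of this repository; the flag values are copied from the commands above, and the sketch assumes it is run from the repository root where `train.py` lives.

```python
import subprocess
import sys

# Per-dataset settings copied from the training commands in this README.
CONFIGS = {
    "cifar10": dict(batch_size=64, base_dim=64, res_blocks=8, max_iter=80000),
    "celeba": dict(batch_size=64, base_dim=32, res_blocks=2, max_iter=60000),
    "imnet32": dict(batch_size=64, base_dim=32, res_blocks=4, max_iter=80000),
    "imnet64": dict(batch_size=64, base_dim=32, res_blocks=2, max_iter=60000),
}


def build_command(dataset, python=sys.executable):
    """Return the argv list for one training run (hypothetical helper)."""
    flags = [f"--{key}={value}" for key, value in CONFIGS[dataset].items()]
    return [python, "train.py", f"--dataset={dataset}"] + flags


def launch(dataset):
    """Run train.py for one dataset; raises CalledProcessError on failure."""
    subprocess.run(build_command(dataset), check=True)
```

For example, `launch("celeba")` would run the CelebA command above. Keeping the settings in one table makes it harder for the four configurations to drift apart.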

About

License: MIT License

