Wasserstein2GenerativeNetworks

PyTorch implementation of the paper "Wasserstein-2 Generative Networks".

Wasserstein-2 Generative Networks

This is the official PyTorch implementation of the paper Wasserstein-2 Generative Networks (preprint on arXiv) by Alexander Korotin, Vage Egiazarian, Arip Asadulaev, Alexander Safin and Evgeny Burnaev.

The repository contains reproducible PyTorch source code for computing optimal transport maps (and distances) in high dimensions via the end-to-end non-minimax method proposed in the paper, using input convex neural networks (ICNNs). Examples are provided for several real-world problems: color transfer, latent space mass transport, domain adaptation, and style transfer.
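
The key mechanic is that the transport map is the gradient of a convex potential. Below is a hedged toy sketch (not the repo's code): `psi` is a hand-made convex quadratic standing in for an ICNN from `src/icnn.py`, and autograd recovers the map T(x) = ∇ψ(x).

```python
import torch

# Toy convex potential psi(x) = 0.5 * ||A x||^2 (stand-in for an ICNN).
torch.manual_seed(0)
A = torch.randn(3, 3)

def psi(x):                        # x: (batch, 3) -> (batch,)
    return 0.5 * (x @ A.T).pow(2).sum(dim=1)

def transport_map(x):
    """T(x) = grad_x psi(x), computed with autograd."""
    x = x.clone().requires_grad_(True)
    (grad,) = torch.autograd.grad(psi(x).sum(), x)
    return grad

x = torch.randn(5, 3)
y = transport_map(x)               # mapped samples, shape (5, 3)
```

For this quadratic potential the gradient is the linear map AᵀA, so the same autograd call generalizes directly to a trained ICNN potential.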

Prerequisites

The implementation is GPU-based. A single GPU (e.g., GTX 1080 Ti) is enough to run each particular experiment. The main prerequisites are:

Repository structure

All experiments are provided as self-explanatory Jupyter notebooks (notebooks/). For convenience, most of the evaluation output is preserved. Auxiliary source code is moved to .py modules (src/).

Experiments

  • notebooks/W2GN_toy_experiments.ipynb -- toy experiments (2D: Swiss Roll, 100 Gaussians, ...);
  • notebooks/W2GN_gaussians_high_dimensions.ipynb -- optimal maps between Gaussians in high dimensions;
  • notebooks/W2GN_latent_space_optimal_transport.ipynb -- latent space optimal transport for CelebA 64x64 aligned images (use this script to rescale the dataset to 64x64);
  • notebooks/W2GN_domain_adaptation.ipynb -- domain adaptation for MNIST-USPS digits datasets;
  • notebooks/W2GN_color_transfer.ipynb -- cycle monotone pixel-wise image-to-image color transfer (example images are provided in data/color_transfer/);
  • notebooks/W2GN_style_transfer.ipynb -- cycle monotone image dataset-to-dataset style transfer (the datasets used are publicly available at the official CycleGAN repo);

Input convex neural networks

  • src/icnn.py -- modules for Input Convex Neural Network architectures (DenseICNN, ConvICNN);
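
The convexity constraint these architectures enforce can be sketched in a few lines. The class below is a minimal toy input-convex network of my own (not the repo's DenseICNN/ConvICNN): hidden-to-hidden weights are clamped non-negative and the activation is convex and non-decreasing, which together guarantee the output is convex in the input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyICNN(nn.Module):
    """Minimal input-convex network (illustrative sketch only).

    Convexity in x holds because:
      * hidden-to-hidden weights are clamped non-negative at forward time,
      * softplus is convex and non-decreasing,
      * direct links from x (arbitrary sign) are affine, hence convex.
    """
    def __init__(self, dim, hidden=64, layers=2):
        super().__init__()
        self.x_links = nn.ModuleList(nn.Linear(dim, hidden) for _ in range(layers))
        self.z_links = nn.ModuleList(
            nn.Linear(hidden, hidden, bias=False) for _ in range(layers - 1))
        self.out = nn.Linear(hidden, 1, bias=False)

    def forward(self, x):
        z = F.softplus(self.x_links[0](x))
        for x_link, z_link in zip(self.x_links[1:], self.z_links):
            # non-negative weights preserve convexity of the composition
            z = F.softplus(x_link(x) + F.linear(z, z_link.weight.clamp(min=0)))
        return F.linear(z, self.out.weight.clamp(min=0)).squeeze(-1)
```

A quick sanity check is Jensen's inequality: for any pair of inputs, the value at the midpoint never exceeds the average of the endpoint values.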

Results

Toy Experiments

Transforming a single Gaussian to a mixture of 100 Gaussians without mode dropping/collapse (and some other toy cases).

Optimal Transport Maps between High Dimensional Gaussians

Assessing the quality of fitted optimal transport maps between two high-dimensional Gaussians (tested in dimensions up to 4096). The metric is Unexplained Variance Percentage (UVP, %).

| Dimension | 2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 | 2048 | 4096 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Large-scale OT | <1 | 3.7 | 7.5 | 14.3 | 23 | 34.7 | 46.9 | >50 | >50 | >50 | >50 | >50 |
| Wasserstein-2 GN | <1 | <1 | <1 | <1 | <1 | <1 | 1 | 1.1 | 1.3 | 1.7 | 1.8 | 1.5 |
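
The Gaussian-to-Gaussian case is special because the true optimal map has a closed form, which is what makes this benchmark possible. The sketch below (my own helper names, one common normalization of UVP by the target's total variance) computes that ground-truth map and scores a candidate map against it.

```python
import numpy as np

def gaussian_ot_map(mu1, S1, mu2, S2):
    """Closed-form OT map between N(mu1, S1) and N(mu2, S2):
       T(x) = mu2 + M (x - mu1),
       M = S1^{-1/2} (S1^{1/2} S2 S1^{1/2})^{1/2} S1^{-1/2}."""
    def sqrtm(S):  # matrix square root of a symmetric PSD matrix
        w, V = np.linalg.eigh(S)
        return (V * np.sqrt(np.clip(w, 0, None))) @ V.T
    r1 = sqrtm(S1)
    r1_inv = np.linalg.inv(r1)
    M = r1_inv @ sqrtm(r1 @ S2 @ r1) @ r1_inv
    return lambda x: mu2 + (x - mu1) @ M.T

def uvp(T_hat, T_star, x, S2):
    """Unexplained variance percentage: mean squared deviation of the
    fitted map from the true map, relative to the target's total variance."""
    err = np.mean(np.sum((T_hat(x) - T_star(x)) ** 2, axis=1))
    return 100.0 * err / np.trace(S2)
```

For example, the true map from N(0, I) to N(0, 4I) is simply x ↦ 2x, and a perfect fit scores UVP = 0.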

Latent Space Optimal Transport

CelebA 64x64 generated faces. The quality of the model depends heavily on the quality of the autoencoder. Use notebooks/AE_Celeba.ipynb to train an MSE or perceptual AE (on VGG features, to improve the AE's visual quality).
Pre-trained autoencoders: MSE-AE [Google Drive, Yandex Disk], VGG-AE [Google Drive, Yandex Disk].

Combining a simple pre-trained MSE autoencoder with W2GN is enough to surpass the Wasserstein GAN model in Fréchet Inception Distance (FID).

| | AE Reconstruct | AE Raw Decode | AE + W2GN | WGAN |
|---|---|---|---|---|
| FID Score | 23.35 | 86.66 | 43.35 | 45.23 |
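
FID itself is the Fréchet (Wasserstein-2) distance between two Gaussians fitted to Inception features. A self-contained sketch of that formula (my own helper, working directly from the Gaussian moments; it avoids a general matrix square root by the identity Tr((S1 S2)^{1/2}) = Tr((S1^{1/2} S2 S1^{1/2})^{1/2})):

```python
import numpy as np

def fid(mu1, S1, mu2, S2):
    """Frechet distance between N(mu1, S1) and N(mu2, S2):
       ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})."""
    def sqrtm(S):  # matrix square root of a symmetric PSD matrix
        w, V = np.linalg.eigh(S)
        return (V * np.sqrt(np.clip(w, 0, None))) @ V.T
    r1 = sqrtm(S1)
    cross = sqrtm(r1 @ S2 @ r1)  # same trace as (S1 S2)^{1/2}
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(S1 + S2 - 2.0 * cross))
```

With equal covariances the score reduces to the squared distance between the means, which makes the formula easy to sanity-check.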

Perceptual VGG autoencoder combined with W2GN provides nearly State-of-the-art FID (compared to Wasserstein GAN with Quadratic Cost).

| | AE Reconstruct | AE Raw Decode | AE + W2GN | WGAN-QC |
|---|---|---|---|---|
| FID Score | 7.5 | 31.81 | 17.21 | 14.41 |

Image-to-Image Color Transfer

Cycle monotone color transfer is applicable even to gigapixel images!
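
Gigapixel inputs are feasible because the color map acts on each pixel's RGB value independently, so pixels can be streamed through the fitted map in fixed-size chunks. A hedged sketch (hypothetical helper, not from the repo):

```python
import torch

def apply_pixelwise(pixel_map, image, chunk=1 << 20):
    """Apply a fitted per-pixel color map T: R^3 -> R^3 to an (H, W, 3) image.

    Pixels are processed in fixed-size chunks, so peak memory stays flat
    regardless of resolution -- which is what makes gigapixel inputs tractable.
    """
    flat = image.reshape(-1, 3)
    out = torch.empty_like(flat)
    for i in range(0, flat.shape[0], chunk):
        out[i:i + chunk] = pixel_map(flat[i:i + chunk])
    return out.reshape(image.shape)
```

In the repo's setting, `pixel_map` would be the gradient of the trained ICNN potential; here any R³ → R³ batch function works.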

Domain Adaptation

MNIST-USPS domain adaptation. PCA Visualization of feature spaces (see the paper for metrics).

Unpaired Image-to-Image Style Transfer

Optimal transport maps in the space of images. The Photo2Cezanne and Winter2Summer datasets are used.
