This is the official Python implementation of the paper Wasserstein-2 Generative Networks (preprint on arXiv) by Alexander Korotin, Vahe Egizarian, Arip Asadulaev, Alexander Safin and Evgeny Burnaev.
The repository contains reproducible PyTorch source code for computing optimal transport maps (and distances) in high dimensions via the end-to-end non-minimax method proposed in the paper, using input convex neural networks. Examples are provided for various real-world problems: color transfer, latent space mass transport, domain adaptation, and style transfer.
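Based on our reading of the paper, the non-minimax idea can be sketched as follows. Two convex potentials ψ and ψ̄ are trained jointly: ∇ψ pushes P to Q, ∇ψ̄ approximates the inverse map, and a cycle penalty (λ/2)·E_Q‖∇ψ(∇ψ̄(y)) − y‖² takes the place of the inner maximization used by minimax approaches. This is a minimal sketch, not the repository's training code. The names `push`, `w2gn_style_loss` and `QuadraticPotential`, the weight `lam`, and the toy data are hypothetical, and the notebooks' actual loss and schedule may differ.

```python
import torch
import torch.nn as nn


def push(potential: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Return grad_x potential(x); create_graph=True keeps it differentiable for training."""
    if not x.requires_grad:
        x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(potential(x).sum(), x, create_graph=True)
    return grad


def w2gn_style_loss(psi, psi_bar, x, y, lam=1.0):
    """Hypothetical non-minimax objective: correlation term + cycle penalty.

    psi:     potential (returns shape (N, 1)) whose gradient should push P (x) to Q (y)
    psi_bar: potential whose gradient should approximate the inverse map Q -> P
    """
    y_inv = push(psi_bar, y)  # grad psi_bar(y)
    # <grad psi_bar(y), y> - psi(grad psi_bar(y)) stands in for the convex conjugate psi*(y)
    conj = (y_inv * y).sum(dim=1) - psi(y_inv).squeeze(-1)
    corr = psi(x).mean() + conj.mean()
    # cycle penalty: grad psi(grad psi_bar(y)) should map back to y
    cycle = (push(psi, y_inv) - y).pow(2).sum(dim=1).mean()
    return corr + 0.5 * lam * cycle


class QuadraticPotential(nn.Module):
    """Hypothetical toy convex potential 0.5*||A x||^2 + <b, x> (convex for any A)."""

    def __init__(self, dim: int):
        super().__init__()
        self.A = nn.Parameter(torch.eye(dim))
        self.b = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * (x @ self.A.T).pow(2).sum(dim=1, keepdim=True) + (x @ self.b).unsqueeze(-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    psi, psi_bar = QuadraticPotential(2), QuadraticPotential(2)
    opt = torch.optim.Adam(list(psi.parameters()) + list(psi_bar.parameters()), lr=1e-2)
    for step in range(300):
        x = torch.randn(256, 2)                                   # P = N(0, I)
        y = torch.randn(256, 2) * torch.tensor([2.0, 0.5]) + 3.0  # Q: rescaled and shifted
        loss = w2gn_style_loss(psi, psi_bar, x, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    print("learned shift b:", psi.b.detach())
```

Both potentials are minimized together, with no inner maximization loop. In practice the quadratic toy potentials would be replaced by input convex networks.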
The implementation is GPU-based. A single GPU (~GTX 1080 Ti) is enough to run each experiment. The main prerequisites are:
- pytorch
- torchvision
- CUDA + cuDNN
All the experiments are provided as self-explanatory Jupyter notebooks (`notebooks/`). For convenience, most of the evaluation output is preserved. Auxiliary source code is moved to `.py` modules (`src/`).
- `notebooks/W2GN_toy_experiments.ipynb` -- toy experiments (2D: Swiss Roll, 100 Gaussians, ...);
- `notebooks/W2GN_gaussians_high_dimensions.ipynb` -- optimal maps between Gaussians in high dimensions;
- `notebooks/W2GN_latent_space_optimal_transport.ipynb` -- latent space optimal transport for CelebA 64x64 Aligned Images (use this script to rescale the dataset to 64x64);
- `notebooks/W2GN_domain_adaptation.ipynb` -- domain adaptation for the MNIST-USPS digit datasets;
- `notebooks/W2GN_color_transfer.ipynb` -- cycle monotone pixel-wise image-to-image color transfer (example images are provided in `data/color_transfer/`);
- `notebooks/W2GN_style_transfer.ipynb` -- cycle monotone image dataset-to-dataset style transfer (the datasets used are publicly available at the official CycleGAN repo);
- `src/icnn.py` -- modules for Input Convex Neural Network architectures (DenseICNN, ConvICNN);
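For readers new to input convex networks, here is a minimal sketch of the general dense ICNN pattern. Each hidden layer sees the input `x` through unconstrained weights and the previous hidden state through non-negative weights, with a convex non-decreasing activation (softplus here), so the scalar output is convex in `x`. `TinyDenseICNN` is a hypothetical illustration, not the `DenseICNN` from `src/icnn.py`, whose layers and options may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDenseICNN(nn.Module):
    """Hypothetical minimal dense ICNN (not the repository's DenseICNN).

    z_1     = softplus(W_0 x + b_0)
    z_{k+1} = softplus(W_k x + b_k + U_k z_k),  U_k >= 0
    output  = w x + b + u z_K,                  u >= 0
    """

    def __init__(self, dim: int, hidden=(64, 64)):
        super().__init__()
        self.x_layers = nn.ModuleList([nn.Linear(dim, h) for h in hidden] + [nn.Linear(dim, 1)])
        sizes = list(hidden) + [1]
        self.z_layers = nn.ModuleList(
            [nn.Linear(h_in, h_out, bias=False) for h_in, h_out in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = F.softplus(self.x_layers[0](x))
        last = len(self.z_layers) - 1
        for i, (x_layer, z_layer) in enumerate(zip(self.x_layers[1:], self.z_layers)):
            # clamp(min=0) enforces non-negative z-path weights on every forward pass
            pre = x_layer(x) + F.linear(z, z_layer.weight.clamp(min=0))
            z = pre if i == last else F.softplus(pre)
        return z  # shape (N, 1), convex in x


if __name__ == "__main__":
    torch.manual_seed(0)
    psi = TinyDenseICNN(dim=2)
    x = torch.randn(5, 2, requires_grad=True)
    (grad,) = torch.autograd.grad(psi(x).sum(), x)  # transport map x -> grad psi(x)
    print(grad.shape)  # torch.Size([5, 2])
```

Clamping in `forward` is one way to keep the network convex. Another common choice is to project the z-path weights onto the non-negative orthant after each optimizer step.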
Transforming a single Gaussian into a mixture of 100 Gaussians without mode dropping/collapse (and some other toy cases).
Assessing the quality of fitted optimal transport maps between two high-dimensional Gaussians (tested in dimensions up to 4096). The metric is the Unexplained Variance Percentage (UVP, %; lower is better).
| Method / Dim | 2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 | 2048 | 4096 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Large-scale OT | <1 | 3.7 | 7.5 | 14.3 | 23 | 34.7 | 46.9 | >50 | >50 | >50 | >50 | >50 |
| Wasserstein-2 GN | <1 | <1 | <1 | <1 | <1 | <1 | 1 | 1.1 | 1.3 | 1.7 | 1.8 | 1.5 |
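Between two Gaussians, the ground-truth optimal map is known in closed form: T*(x) = μ_Q + A(x − μ_P) with A = Σ_P^{-1/2}(Σ_P^{1/2} Σ_Q Σ_P^{1/2})^{1/2} Σ_P^{-1/2}. This is what makes such a benchmark possible. The NumPy sketch below computes T* and a Monte Carlo estimate of UVP. It assumes our reading of the metric as 100 · E_P‖T̂(x) − T*(x)‖² / Var(Q), where Var(Q) is the trace of Σ_Q. The function names are hypothetical, and the notebook's exact evaluation protocol may differ.

```python
import numpy as np


def sqrtm_psd(S: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, V = np.linalg.eigh(0.5 * (S + S.T))
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def gaussian_ot_map(mu_p, cov_p, mu_q, cov_q):
    """Closed-form W2-optimal map from N(mu_p, cov_p) to N(mu_q, cov_q); cov_p must be nonsingular."""
    root_p = sqrtm_psd(cov_p)
    inv_root_p = np.linalg.inv(root_p)
    A = inv_root_p @ sqrtm_psd(root_p @ cov_q @ root_p) @ inv_root_p  # symmetric PSD
    return lambda x: mu_q + (x - mu_p) @ A.T


def l2_uvp(fitted_map, true_map, x, cov_q) -> float:
    """UVP in %: 100 * mean ||T_hat(x) - T*(x)||^2 / trace(cov_q), with x sampled from P."""
    err = np.mean(np.sum((fitted_map(x) - true_map(x)) ** 2, axis=1))
    return float(100.0 * err / np.trace(cov_q))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d = 8
    B = rng.normal(size=(d, d))
    C = rng.normal(size=(d, d))
    cov_p, cov_q = B @ B.T + np.eye(d), C @ C.T + np.eye(d)
    mu_p, mu_q = np.zeros(d), rng.normal(size=d)
    true_map = gaussian_ot_map(mu_p, cov_p, mu_q, cov_q)
    x = rng.multivariate_normal(mu_p, cov_p, size=10_000)
    print(f"UVP of the identity map: {l2_uvp(lambda z: z, true_map, x, cov_q):.1f}%")
```

Normalizing by Var(Q) makes the score comparable across dimensions. A map that outputs the constant μ_Q scores exactly 100%.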
CelebA 64x64 generated faces. The quality of the model depends heavily on the quality of the autoencoder. Use `notebooks/AE_Celeba.ipynb` to train an MSE or a perceptual AE (on VGG features, to improve AE visual quality).
Pre-trained autoencoders: MSE-AE [Google Drive, Yandex Disk], VGG-AE [Google Drive, Yandex Disk].
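The generation pipeline implied above is: sample noise, move it into the autoencoder's latent space with the gradient of a fitted convex potential, then decode. This PyTorch sketch only shows how the pieces connect. The decoder, the latent size, and the identity stand-in potential `HalfSquaredNorm` are hypothetical placeholders for a pre-trained AE decoder and a trained W2GN potential.

```python
import torch
import torch.nn as nn


def push(potential: nn.Module, z: torch.Tensor) -> torch.Tensor:
    """grad_z potential(z): the transport map induced by a convex potential."""
    z = z.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(potential(z).sum(), z)
    return grad


def sample_faces(decoder: nn.Module, potential: nn.Module, n: int, latent_dim: int) -> torch.Tensor:
    """Noise -> AE latent space (via the fitted map) -> image space (via the AE decoder)."""
    z = torch.randn(n, latent_dim)
    latents = push(potential, z)
    with torch.no_grad():
        return decoder(latents)


class HalfSquaredNorm(nn.Module):
    """Hypothetical stand-in potential 0.5*||z||^2, whose gradient is the identity map."""

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return 0.5 * z.pow(2).sum(dim=1)


if __name__ == "__main__":
    latent_dim = 128  # hypothetical latent size
    decoder = nn.Sequential(  # hypothetical stand-in for a pre-trained AE decoder
        nn.Linear(latent_dim, 3 * 64 * 64), nn.Sigmoid(), nn.Unflatten(1, (3, 64, 64))
    )
    images = sample_faces(decoder, HalfSquaredNorm(), n=16, latent_dim=latent_dim)
    print(images.shape)  # torch.Size([16, 3, 64, 64])
```

With the identity stand-in, the samples are just decoded noise. A trained potential is what moves the noise onto the distribution of encoded faces.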
Combining a simple pre-trained MSE autoencoder with W2GN is enough to outperform the Wasserstein GAN model in Fréchet Inception Distance (FID, lower is better).
| | AE Reconstruct | AE Raw Decode | AE + W2GN | WGAN |
|---|---|---|---|---|
| FID Score | 23.35 | 86.66 | 43.35 | 45.23 |
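FID is the Fréchet distance between Gaussian fits of Inception-v3 activations for real and generated images: ‖μ_a − μ_b‖² + Tr(Σ_a + Σ_b − 2(Σ_aΣ_b)^{1/2}). The NumPy sketch below computes only that distance from precomputed feature matrices and omits Inception feature extraction. The function names are hypothetical. Reported scores should come from a standard FID implementation, since preprocessing and sample counts affect the value.

```python
import numpy as np


def _sqrtm_psd(S: np.ndarray) -> np.ndarray:
    """Principal square root of a symmetric PSD matrix via eigendecomposition."""
    S = 0.5 * (S + S.T)  # remove tiny asymmetries caused by round-off
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^{1/2}) for Gaussian fits of two feature sets.

    Uses Tr((S_a S_b)^{1/2}) = Tr((S_a^{1/2} S_b S_a^{1/2})^{1/2}), so only square roots
    of symmetric PSD matrices are needed.
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    root_a = _sqrtm_psd(cov_a)
    cross = _sqrtm_psd(root_a @ cov_b @ root_a)
    return float(np.sum((mu_a - mu_b) ** 2) + np.trace(cov_a) + np.trace(cov_b) - 2.0 * np.trace(cross))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real = rng.normal(size=(5000, 64))             # placeholder for real-image features
    fake = rng.normal(size=(5000, 64)) * 1.1 + 0.2  # placeholder for generated-image features
    print(f"Frechet distance: {frechet_distance(real, fake):.3f}")
```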
A perceptual (VGG) autoencoder combined with W2GN provides nearly state-of-the-art FID (compared to Wasserstein GAN with Quadratic Cost, WGAN-QC).
| | AE Reconstruct | AE Raw Decode | AE + W2GN | WGAN-QC |
|---|---|---|---|---|
| FID Score | 7.5 | 31.81 | 17.21 | 14.41 |
Cycle monotone color transfer is applicable even to gigapixel images!
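Pixel-wise color transfer treats every pixel as a point in 3D color space and applies one map to each pixel independently. This is why the map can be applied to arbitrarily large images in fixed-size chunks. In the sketch below, a closed-form linear map between Gaussian fits of the two color distributions is swapped in for illustration. It is not the W2GN map from the notebook, and the function names and chunk size are hypothetical. Images are assumed to be float arrays in [0, 1].

```python
import numpy as np


def _sqrtm_psd(S: np.ndarray) -> np.ndarray:
    """Square root of a symmetric PSD matrix via eigendecomposition."""
    w, V = np.linalg.eigh(0.5 * (S + S.T))
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def fit_linear_color_map(src_pixels: np.ndarray, ref_pixels: np.ndarray):
    """Stand-in map: W2-optimal map between Gaussian fits of two (N, 3) color clouds (NOT W2GN)."""
    mu_s, mu_r = src_pixels.mean(axis=0), ref_pixels.mean(axis=0)
    cov_s = np.cov(src_pixels, rowvar=False) + 1e-6 * np.eye(3)  # small ridge keeps it invertible
    cov_r = np.cov(ref_pixels, rowvar=False)
    root_s = _sqrtm_psd(cov_s)
    inv_root_s = np.linalg.inv(root_s)
    A = inv_root_s @ _sqrtm_psd(root_s @ cov_r @ root_s) @ inv_root_s
    return lambda p: mu_r + (p - mu_s) @ A.T


def transfer_colors(image: np.ndarray, color_map, chunk: int = 1_000_000) -> np.ndarray:
    """Apply a pixel-wise map to an HxWx3 float image chunk by chunk to bound memory use."""
    flat = image.reshape(-1, 3)
    out = np.empty(flat.shape, dtype=np.float64)
    for start in range(0, flat.shape[0], chunk):
        out[start:start + chunk] = color_map(flat[start:start + chunk])
    return np.clip(out, 0.0, 1.0).reshape(image.shape)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    content = rng.random((120, 160, 3))        # placeholder images
    reference = rng.random((100, 100, 3)) ** 2
    cmap = fit_linear_color_map(content.reshape(-1, 3), reference.reshape(-1, 3))
    result = transfer_colors(content, cmap, chunk=4096)
    print(result.shape, result.dtype)
```

For truly huge images, the input could be a memory-mapped array (for example `np.memmap`), so that only one chunk sits in RAM at a time.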
MNIST-USPS domain adaptation. PCA visualization of the feature spaces (see the paper for metrics).
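To compare two domains in one PCA plot, both feature sets must be projected onto the same principal axes. This NumPy sketch fits PCA on the union of the features and projects each set onto the top two components. `joint_pca_2d` is a hypothetical helper and not necessarily how the notebook builds its figure. The 2D outputs can then be scatter-plotted with any plotting library.

```python
import numpy as np


def joint_pca_2d(feats_a: np.ndarray, feats_b: np.ndarray):
    """Project two feature sets onto the top-2 principal axes of their union (shared basis)."""
    both = np.concatenate([feats_a, feats_b], axis=0)
    mean = both.mean(axis=0)
    _, _, vt = np.linalg.svd(both - mean, full_matrices=False)  # rows of vt: principal axes
    basis = vt[:2].T
    return (feats_a - mean) @ basis, (feats_b - mean) @ basis


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    source = rng.normal(size=(500, 32))            # placeholder for, e.g., USPS features
    target = rng.normal(size=(500, 32)) * 2 + 1.0  # placeholder for, e.g., MNIST features
    src_2d, tgt_2d = joint_pca_2d(source, target)
    print(src_2d.shape, tgt_2d.shape)
```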
Optimal transport map in the space of images. Photo2Cezanne and Winter2Summer datasets are used.