Augmented-Sliced-Wasserstein-Distances

This repository provides the code to reproduce the experimental results in the paper Augmented Sliced Wasserstein Distances.

Prerequisites

Python packages

To install the required python packages, run the following command:

pip install -r requirements.txt

Datasets

Two datasets are used in this repository, namely the CIFAR10 dataset and the CELEBA dataset.
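As a hedged illustration, the two datasets can be prepared with torchvision; the repository's own data pipeline, root path, image size and normalization may differ, so treat the snippet below as an assumption rather than the actual setup.

# Hypothetical data preparation with torchvision (paths and transforms are assumptions).
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR10 downloads automatically; CELEBA may require a manual download if the
# hosting quota is exceeded.
cifar10 = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
celeba = torchvision.datasets.CelebA(root="./data", split="train", download=True, transform=transform)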

Precalculated Statistics

To calculate the Fréchet Inception Distance (FID score), precalculated statistics for the datasets are provided at http://bioinf.jku.at/research/ttur/.
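As a minimal sketch, the downloaded statistics can be combined with statistics of generated images using the standard FID formula; the file name and the 'mu'/'sigma' keys below follow the TTUR release but should be checked against the downloaded archive.

# Minimal FID sketch from precalculated statistics (file name and keys are assumptions).
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

stats = np.load("fid_stats_cifar10_train.npz")
mu_real, sigma_real = stats["mu"], stats["sigma"]
# mu_fake, sigma_fake would be computed from InceptionV3 features of generated images.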

Project & Script Descriptions

Two experiments are included in this repository, with benchmarks taken from the papers Generalized Sliced Wasserstein Distances and Distributional Sliced-Wasserstein and Applications to Generative Modeling, respectively. The first experiment is on the task of sliced Wasserstein flow, and the second is on generative modelling with GANs. For more details and setups, please refer to the original paper Augmented Sliced Wasserstein Distances.
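For orientation, both experiments build on the vanilla sliced Wasserstein distance, which projects samples onto random directions and averages one-dimensional Wasserstein distances. The sketch below is illustrative only and assumes equal sample sizes; the repository's utils.py implements the actual ASWD and baseline variants.

# Illustrative Monte Carlo estimate of the sliced Wasserstein distance (not the repository's code).
import torch

def sliced_wasserstein(x, y, num_projections=1000, p=2):
    # x, y: (n, d) sample batches with the same number of samples
    d = x.size(1)
    theta = torch.randn(num_projections, d, device=x.device)
    theta = theta / theta.norm(dim=1, keepdim=True)   # random unit directions
    x_proj = x @ theta.t()                            # (n, num_projections)
    y_proj = y @ theta.t()
    x_sorted, _ = torch.sort(x_proj, dim=0)           # 1-D Wasserstein via sorting
    y_sorted, _ = torch.sort(y_proj, dim=0)
    return (x_sorted - y_sorted).abs().pow(p).mean().pow(1.0 / p)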

Directories

  • ./result/ASWD/CIFAR/ contains generated images trained with the ASWD on the CIFAR10 dataset.
  • ./result/ASWD/CIFAR/fid/ FID scores of generated images trained with the ASWD on the CIFAR10 dataset are saved in this folder.
  • ./result/CIFAR/ model weights and losses from the CIFAR10 experiment are stored in this directory.

Other setups follow the same naming rule.

Scripts

The sliced Wasserstein flow example can be found in the Jupyter notebook.
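As a hedged sketch of what such a flow involves, assuming the notebook follows the usual recipe, the source samples are treated as parameters and updated by gradient descent on a sliced distance to a fixed target sample set; the step size and projection count below are placeholders.

# Hypothetical sliced Wasserstein flow loop (hyperparameters are placeholders).
import torch

def sw_flow(source, target, steps=500, lr=1e-2, num_projections=100):
    x = source.detach().clone().requires_grad_(True)
    opt = torch.optim.SGD([x], lr=lr)
    for _ in range(steps):
        theta = torch.randn(num_projections, x.size(1))
        theta = theta / theta.norm(dim=1, keepdim=True)
        xs, _ = torch.sort(x @ theta.t(), dim=0)
        ys, _ = torch.sort(target @ theta.t(), dim=0)
        loss = (xs - ys).pow(2).mean()   # squared SW_2 estimate
        opt.zero_grad()
        loss.backward()
        opt.step()
    return x.detach()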

The following scripts belong to the generative modelling example:

  • main.py : run this file to conduct experiments.
  • utils.py : contains implementations of the different slicing-based Wasserstein distances.
  • TransformNet.py : edit this file to modify the architecture of the neural network used to map samples (a hedged sketch of the underlying idea follows this list).
  • experiments.py : functions for generating and saving randomly generated images.
  • DCGANAE.py : neural network architectures and optimization objective for training GANs.
  • fid_score.py : functions for calculating statistics (mean & covariance matrix) of distributions of images and the FID score between two distributions of images.
  • inception.py : downloads the pretrained InceptionV3 model and generates feature maps for the FID evaluation.
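The idea behind TransformNet.py, as described in the ASWD paper, is to map samples into a higher-dimensional space with an injective augmentation x -> (x, g(x)) before slicing. The stand-in module below is only a sketch; the layer sizes and the name AugmentationNet are assumptions and do not mirror the repository's actual architecture.

# Hypothetical stand-in for the sample-mapping network (not the repository's TransformNet).
import torch
import torch.nn as nn

class AugmentationNet(nn.Module):
    def __init__(self, dim, hidden=512, out=128):
        super().__init__()
        self.g = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, out))

    def forward(self, x):
        # Concatenating x with g(x) keeps the mapping injective; the sliced distance
        # is then computed on these augmented samples.
        return torch.cat([x, self.g(x)], dim=1)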

Experiment options for the generative modelling example

The generative modelling experiment evaluates the performance of GANs trained with different slicing-based Wasserstein metrics. To train and evaluate a model, run the following command (a sketch of how these options might map onto an argument parser is given after the parameter lists):

python main.py  --model-type ASWD --dataset CIFAR --epochs 200 --num-projection 1000 --batch-size 512 --lr 0.0005
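The same flags apply to the other metrics and to the CELEBA dataset; for example, a DSWD run on CELEBA would look like the following (hyperparameter values are simply carried over from the example above):

python main.py --model-type DSWD --dataset CELEBA --epochs 200 --num-projection 1000 --batch-size 512 --lr 0.0005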

Basic parameters

  • --model-type : type of slicing-based Wasserstein metric used in the experiment; available options: ASWD, DSWD, SWD, MSWD, GSWD. Must be specified.
  • --dataset : select from CIFAR, CELEBA; defaults to CIFAR.
  • --epochs : number of training epochs; defaults to 200.
  • --num-projection : number of projections used in the distance approximation; defaults to 1000.
  • --batch-size : batch size for one iteration; defaults to 512.
  • --lr : learning rate; defaults to 0.0005.

Optional parameters

  • --niter : number of iterations; available for the ASWD, MSWD and DSWD; defaults to 5.
  • --lam : coefficient of the regularization term; available for the ASWD and DSWD; defaults to 0.5.
  • --r : parameter of the circular defining function; available for the GSWD; defaults to 1000.
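As referenced above, the sketch below illustrates how the basic and optional parameters could be wired into an argument parser; it is an assumption for orientation, not main.py's actual parser.

# Hypothetical argument parser mirroring the options listed above (not main.py's actual code).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model-type", required=True,
                    choices=["ASWD", "DSWD", "SWD", "MSWD", "GSWD"])
parser.add_argument("--dataset", default="CIFAR", choices=["CIFAR", "CELEBA"])
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--num-projection", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=512)
parser.add_argument("--lr", type=float, default=0.0005)
parser.add_argument("--niter", type=int, default=5)     # ASWD / MSWD / DSWD only
parser.add_argument("--lam", type=float, default=0.5)   # ASWD / DSWD only
parser.add_argument("--r", type=float, default=1000)    # GSWD only
args = parser.parse_args()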

References

Code

The code of the generative modelling example is based on the implementation of the DSWD by VinAI Research.

The PyTorch code for calculating the FID score is from https://github.com/mseitzer/pytorch-fid.
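If the standalone pytorch-fid package is installed, the same computation can also be run from the command line as shown below; note that the vendored fid_score.py in this repository may expect different arguments.

python -m pytorch_fid path/to/real_images path/to/generated_images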

Papers

  • Augmented Sliced Wasserstein Distances
  • Generalized Sliced Wasserstein Distances
  • Distributional Sliced-Wasserstein and Applications to Generative Modeling
