w86763777 / pytorch-gan-collections

PyTorch implementation of DCGAN, WGAN-GP and SNGAN.

Collections of GANs

PyTorch implementation of basic unsupervised GANs on CIFAR10.

More details about calculating the Inception Score and FID with PyTorch can be found in pytorch_gan_metrics.
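The Inception Score is exp(E_x[KL(p(y|x) || p(y))]). It rewards images that a classifier labels confidently and that, taken together, cover many classes. The NumPy sketch below illustrates the formula under some assumptions. It takes the class probabilities p(y|x) as given, so running an Inception network is not shown. The 10-split mean/std convention is borrowed from common practice. The function name `inception_score` is hypothetical, and pytorch_gan_metrics may compute the score differently.

```python
import numpy as np


def inception_score(probs: np.ndarray, splits: int = 10, eps: float = 1e-12):
    """Hypothetical helper: IS = exp(E_x[KL(p(y|x) || p(y))]).

    probs: (N, C) array, each row p(y|x) for one generated image. These are
    assumed to come from an Inception classifier (not computed here).
    Returns the mean and standard deviation of the score over `splits` chunks.
    """
    scores = []
    for chunk in np.array_split(probs, splits):
        # Marginal class distribution p(y) estimated on this chunk.
        p_y = chunk.mean(axis=0, keepdims=True)
        # Element-wise KL terms; eps guards against log(0).
        kl = chunk * (np.log(chunk + eps) - np.log(p_y + eps))
        # Sum over classes, average over images, then exponentiate.
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))
```

A classifier that outputs a uniform distribution for every image gives a score of 1. Confident predictions spread evenly over C classes give a score of C.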

Models

  • DCGAN
  • WGAN
  • WGAN-GP
  • SN-GAN
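WGAN-GP replaces WGAN's weight clipping with a penalty that pushes the critic's gradient norm toward 1 on random interpolates between real and generated samples. The PyTorch sketch below follows the standard formulation from the WGAN-GP paper. It is not taken from wgangp.py, and the name `gradient_penalty` is hypothetical.

```python
import torch
from torch import nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: E[(||grad_x D(x_hat)||_2 - 1)^2] on random interpolates."""
    # One interpolation coefficient per sample, broadcast over the remaining dims.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    # Detach inputs so the penalty only trains the critic, not the generator.
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph so the penalty can be backpropagated
    # into the critic's parameters.
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x_hat, create_graph=True)
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()
```

In a critic step, one would typically minimize `critic(fake).mean() - critic(real).mean() + lam * gradient_penalty(critic, real, fake)`. The paper uses lam = 10. The value used by the configs in this repository is not stated here. For SN-GAN, PyTorch ships spectral normalization as `torch.nn.utils.spectral_norm`, which wraps a layer such as `nn.Conv2d`.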

Requirements

  • Install Python packages
    pip install -U pip setuptools
    pip install -r requirements.txt

Results

FID is calculated between 50k generated images and the CIFAR10 training set.

Model            Dataset  Inception Score  FID
DCGAN            CIFAR10  6.01(0.05)       42.72
WGAN(CNN)        CIFAR10  6.62(0.09)       40.03
WGAN-GP(CNN)     CIFAR10  7.66(0.10)       19.83
WGAN-GP(ResNet)  CIFAR10  7.95(0.14)       16.95
SNGAN(CNN)       CIFAR10  7.84(0.12)       17.81
SNGAN(ResNet)    CIFAR10  8.31(0.10)       14.32
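FID, used in the table above, is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2(Sigma_r Sigma_g)^(1/2)). The NumPy sketch below implements that standard formula. It is not the pytorch_gan_metrics code, and the helper names are hypothetical. It assumes the feature vectors, or their precomputed mean and covariance, are already available. The matrix square root is handled through the eigenvalues of Sigma_r Sigma_g instead of `scipy.linalg.sqrtm`.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """Hypothetical helper: ||mu1 - mu2||^2 + Tr(s1 + s2 - 2 (s1 s2)^(1/2))."""
    mu1, mu2 = np.asarray(mu1, dtype=float), np.asarray(mu2, dtype=float)
    sigma1, sigma2 = np.asarray(sigma1, dtype=float), np.asarray(sigma2, dtype=float)
    diff = mu1 - mu2
    # For PSD covariances, sigma1 @ sigma2 has real, non-negative eigenvalues,
    # and Tr((sigma1 sigma2)^(1/2)) is the sum of their square roots. Clip
    # round-off (tiny negative or imaginary parts) before taking the root.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def feature_stats(features):
    """Hypothetical helper: mean and covariance of an (N, D) feature array."""
    features = np.asarray(features, dtype=np.float64)
    return features.mean(axis=0), np.cov(features, rowvar=False)
```

Identical statistics give 0. Two Gaussians with identity covariance and means (0, 0, 0) and (1, 2, 2) give 9, the squared distance between the means.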

Examples

  • DCGAN

    [images: dcgan_gif, dcgan_png]

  • WGAN(CNN)

    [images: wgan_gif, wgan_png]

  • WGAN-GP(CNN)

    [images: wgangp_cnn_gif, wgangp_cnn_png]

  • WGAN-GP(ResNet)

    [images: wgangp_res_gif, wgangp_res_png]

  • SNGAN(CNN)

    [images: sngan_cnn_gif, sngan_cnn_png]

  • SNGAN(ResNet)

    [images: sngan_res_gif, sngan_res_png]

Reproduce

  • Download cifar10.train.npz, which is used to calculate FID, and place it in a folder named stats:

    stats
    └── cifar10.train.npz
    
  • Train from scratch

    Each method is implemented in a separate file for readability.

    # DCGAN
    python dcgan.py --flagfile ./configs/DCGAN_CIFAR10.txt
    # WGAN(CNN)
    python wgan.py --flagfile ./configs/WGAN_CIFAR10_CNN.txt
    # WGAN-GP(CNN)
    python wgangp.py --flagfile ./configs/WGANGP_CIFAR10_CNN.txt
    # WGAN-GP(ResNet)
    python wgangp.py --flagfile ./configs/WGANGP_CIFAR10_RES.txt
    # SNGAN(CNN)
    python sngan.py --flagfile ./configs/SNGAN_CIFAR10_CNN.txt
    # SNGAN(ResNet)
    python sngan.py --flagfile ./configs/SNGAN_CIFAR10_RES.txt
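To train all six configurations in sequence, a small driver along the following lines could be used. It is a hypothetical helper, not part of the repository. It only reuses the script names and flag files listed above, and it refuses to start if stats/cifar10.train.npz is missing.

```python
import subprocess
import sys
from pathlib import Path

# (script, flagfile) pairs copied from the training commands above.
RUNS = [
    ("dcgan.py", "./configs/DCGAN_CIFAR10.txt"),
    ("wgan.py", "./configs/WGAN_CIFAR10_CNN.txt"),
    ("wgangp.py", "./configs/WGANGP_CIFAR10_CNN.txt"),
    ("wgangp.py", "./configs/WGANGP_CIFAR10_RES.txt"),
    ("sngan.py", "./configs/SNGAN_CIFAR10_CNN.txt"),
    ("sngan.py", "./configs/SNGAN_CIFAR10_RES.txt"),
]


def build_commands(python=sys.executable):
    """Hypothetical helper: one argv list per training run."""
    return [[python, script, "--flagfile", flagfile] for script, flagfile in RUNS]


def run_all(dry_run=True, stats=Path("stats/cifar10.train.npz")):
    """Hypothetical helper: print (and optionally execute) every training run."""
    if not stats.is_file():
        raise FileNotFoundError(f"missing FID statistics: {stats}")
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError and stops at the first failing run.
            subprocess.run(cmd, check=True)
```

With `dry_run=True` (the default), the helper only prints the commands. `run_all(dry_run=False)`, called from the repository root, launches the runs one after another.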

Learning Curves

[images: inception_score_curve, fid_curve]

Change Log

  • 2022-01-10

    • Update PyTorch to 1.10.1 and CUDA 11.3
    • Use pytorch_gan_metrics to calculate FID and Inception Score
    • Use 50k generated images and the CIFAR10 training set to calculate FID
    • Fix default parameters, especially for wgan.py
  • 2021-04-16

    • Update PyTorch to 1.8.1
    • Move metrics to a submodule.
    • Evaluate FID on the CIFAR10 test set instead of the training set.
    • Fix cifar10.test.npz download link and sample images.
