sumorday / EDGAN-PyTorch

Official implementation for Penalty Gradient Normalization for Generative Adversarial Networks


Penalty Gradient Normalization for Generative Adversarial Networks

Requirements

  • Python 3.8.9
  • Python packages
    # Update `pip` so that tensorboard can be installed.
    pip install -U pip setuptools
    pip install -r requirements.txt

Datasets

  • CIFAR-10

    The PyTorch built-in CIFAR-10 dataset will be downloaded automatically (see the loading sketch after this list).

  • STL-10

    The PyTorch built-in STL-10 dataset will be downloaded automatically.

  • CelebA-HQ 128/256

    We obtain CelebA-HQ from this repository and preprocess it into an lmdb file.

    • 256x256

      python dataset.py path/to/celebahq/256 ./data/celebahq/256
      
    • 128x128

      We split the data into train/test splits by filename; the test set contains images from 27001.jpg to 30000.jpg.

      python dataset.py path/to/celebahq/128/train ./data/celebahq/128
      

    The folder structure:

    ./data/celebahq
    ├── 128
    │   ├── data.mdb
    │   └── lock.mdb
    └── 256
        ├── data.mdb
        └── lock.mdb
    
  • LSUN Church Outdoor 256x256 (training set)

    The folder structure:

    ./data/lsun/church/
    ├── data.mdb
    └── lock.mdb
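
For reference, the sketch below shows how the built-in datasets are typically fetched: torchvision downloads them on first use when download=True. The root directory and transforms here are illustrative and may differ from what the training scripts actually use.

import torchvision
import torchvision.transforms as T

# Illustrative only: torchvision fetches CIFAR-10 / STL-10 on first use.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
cifar10 = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
stl10 = torchvision.datasets.STL10(
    root='./data', split='unlabeled', download=True, transform=transform)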
    

Preprocessing Datasets for FID

Pre-calculated statistics for FID can be downloaded here (a usage sketch follows the folder structure below):

  • cifar10.train.npz - Training set of CIFAR10
  • cifar10.test.npz - Testing set of CIFAR10
  • stl10.unlabeled.48.npz - Unlabeled set of STL10 in resolution 48x48
  • celebahq.3k.128.npz - Last 3k images of CelebA-HQ 128x128
  • celebahq.all.256.npz - Full dataset of CelebA-HQ 256x256
  • church.train.256.npz - Training set of LSUN Church Outdoor

Folder structure:

./stats
├── celebahq.3k.128.npz
├── celebahq.all.256.npz
├── church.train.256.npz
├── cifar10.test.npz
├── cifar10.train.npz
└── stl10.unlabeled.48.npz
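
As an illustration of how such pre-calculated statistics are typically consumed, the sketch below computes the Fréchet distance between two statistics files. It assumes each .npz stores the Inception activation mean and covariance under the keys 'mu' and 'sigma' (the common pytorch-fid convention); the key names and file choices here are assumptions, and the actual evaluation is driven by the training scripts.

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # Fréchet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2).
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Add a small diagonal offset for numerical stability.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

# Assumed key names ('mu', 'sigma'); adjust to match the downloaded files.
ref = np.load('./stats/cifar10.train.npz')
gen = np.load('./stats/cifar10.test.npz')
print('FID:', frechet_distance(ref['mu'], ref['sigma'], gen['mu'], gen['sigma']))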

NOTE

All reported values (Inception Score and FID) in our paper are calculated with the official implementations, not with this implementation.

Training

  • Configuration files

    • We use absl-py to parse, save, and reload the command-line arguments (a minimal flag-handling sketch is shown at the end of this section).

    • All the configuration files can be found in ./config.

    • The compatible configuration list is shown in the following table:

      Script        Configurations                  Multi-GPU
      train.py      GN-GAN_CIFAR10_CNN.txt
                    GN-GAN_CIFAR10_RES.txt
                    GN-GAN_CIFAR10_BIGGAN.txt
                    GN-GAN_STL10_CNN.txt
                    GN-GAN_STL10_RES.txt
                    GN-GAN-CR_CIFAR10_CNN.txt
                    GN-GAN-CR_CIFAR10_RES.txt
                    GN-GAN-CR_CIFAR10_BIGGAN.txt
                    GN-GAN-CR_STL10_CNN.txt
                    GN-GAN-CR_STL10_RES.txt
      train_ddp.py  GN-GAN_CELEBAHQ128_RES.txt      ✔️
                    GN-GAN_CELEBAHQ256_RES.txt
                    GN-GAN_CHURCH256_RES.txt
  • Run the training script with a compatible configuration, e.g.,

    • train.py supports training GANs on CIFAR-10 and STL-10, e.g.,
      python train.py \
          --flagfile ./config/GN-GAN_CIFAR10_RES.txt
    • train_ddp.py is optimized for multi-GPU training, e.g.,
      CUDA_VISIBLE_DEVICES=0,1,2,3 python train_ddp.py \
          --flagfile ./config/GN-GAN_CELEBAHQ256_RES.txt
      
  • Generate images from checkpoints, e.g.,

    --eval: evaluate the best checkpoint.

    --save PATH: save the generated images to PATH.

    python train.py \
        --flagfile ./logs/GN-GAN_CIFAR10_RES/flagfile.txt \
        --eval \
        --save path/to/generated/images
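
The --flagfile mechanism comes from absl-py: each file in ./config is simply a list of command-line flags, one per line, and flags passed on the command line override the ones in the file. Below is a minimal, self-contained sketch of this pattern; the flag names are illustrative and are not the exact flags defined by train.py.

from absl import app, flags

FLAGS = flags.FLAGS
# Illustrative flag definitions; the real training scripts define many more.
flags.DEFINE_string('dataset', 'cifar10', 'Dataset name.')
flags.DEFINE_integer('total_steps', 200000, 'Number of generator updates.')
flags.DEFINE_float('lr_G', 2e-4, 'Generator learning rate.')


def main(argv):
    del argv  # unused
    print(FLAGS.dataset, FLAGS.total_steps, FLAGS.lr_G)
    # FLAGS.append_flags_into_file(path) can save the parsed flags so a run
    # can later be reproduced or evaluated with --flagfile path.


if __name__ == '__main__':
    app.run(main)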
    

How to integrate Gradient Normalization into your work?

The function penalty_normalize_gradient is built on top of the torch.autograd module; it normalizes the forward pass of your discriminator by changing a single line.

import torch
from torch.nn import BCEWithLogitsLoss
from models.gradnorm import penalty_normalize_gradient

net_D = ...     # discriminator
net_G = ...     # generator
loss_fn = BCEWithLogitsLoss()

# Update discriminator
x_real = ...                                    # real data
x_fake = net_G(torch.randn(64, 3, 32, 32))      # fake data
pred_real = penalty_normalize_gradient(net_D, x_real)   # net_D(x_real)
pred_fake = penalty_normalize_gradient(net_D, x_fake)   # net_D(x_fake)
loss_real = loss_fn(pred_real, torch.ones_like(pred_real))
loss_fake = loss_fn(pred_fake, torch.zeros_like(pred_fake))
(loss_real + loss_fake).backward()              # backward propagation
...

# Update generator
x_fake = net_G(torch.randn(64, 3, 32, 32))      # fake data
pred_fake = penalty_normalize_gradient(net_D, x_fake)   # net_D(x_fake)
loss_fake = loss_fn(pred_fake, torch.ones_like(pred_fake))
loss_fake.backward()                            # backward propagation
...
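
For intuition about what penalty_normalize_gradient does internally, the sketch below shows how a gradient-normalized forward pass can be built with torch.autograd: evaluate the discriminator, differentiate its output with respect to the input, and rescale the output by the input-gradient norm. This is a simplified illustration of the gradient-normalization idea, f_hat = f / (||grad_x f|| + |f|); the penalty variant implemented in models/gradnorm.py may use a different normalizing term.

import torch

def normalize_gradient_sketch(net_D, x, **kwargs):
    # Illustrative gradient-normalized forward pass, not the exact
    # penalty formulation shipped in models/gradnorm.py.
    x.requires_grad_(True)
    f = net_D(x, **kwargs)                                  # raw discriminator output
    grad = torch.autograd.grad(
        f, x, grad_outputs=torch.ones_like(f),
        create_graph=True, retain_graph=True)[0]            # d f / d x
    grad_norm = torch.flatten(grad, start_dim=1).norm(p=2, dim=1)
    grad_norm = grad_norm.view(-1, *[1] * (f.dim() - 1))    # broadcast to f's shape
    return f / (grad_norm + torch.abs(f))                   # bounded, normalized output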

Citation

If you find our work relevant to your research, please cite:

