hiyouga / SAGAN-PyTorch

A PyTorch implementation for Self-Attention Generative Adversarial Networks

SAGAN-PyTorch

A PyTorch implementation of Zhang et al.'s ICML 2019 paper "Self-Attention Generative Adversarial Networks". [arXiv] [PMLR]
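
The core of SAGAN is a self-attention layer that lets every location of a feature map attend to every other location. The minimal NumPy sketch below follows the formulation in the paper: 1x1 convolutions are written as channel-wise matrix multiplies, the query/key/value projections reduce C channels to C_bar (C // 8 in the paper), and the attention output is added back through a learnable scale gamma that the paper initializes to 0. It is illustrative only and is not the code in this repository; the helper names self_attention and softmax and the weight names W_f, W_g, W_h, W_v are hypothetical.

```python
import numpy as np


def softmax(a, axis):
    # Numerically stable softmax along the given axis.
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, W_f, W_g, W_h, W_v, gamma):
    """SAGAN-style self-attention on one feature map (hypothetical helper).

    x:        (C, H, W) feature map
    W_f, W_g: (C_bar, C) query/key projections (1x1 convs)
    W_h:      (C_bar, C) value projection
    W_v:      (C, C_bar) output projection back to C channels
    gamma:    learnable scalar (initialized to 0 in the paper)
    Returns y of shape (C, H, W) and the attention matrix beta of shape (N, N).
    """
    C, H, W = x.shape
    N = H * W
    x_flat = x.reshape(C, N)      # column i is the feature vector x_i of location i

    f = W_f @ x_flat              # (C_bar, N) queries
    g = W_g @ x_flat              # (C_bar, N) keys
    h = W_h @ x_flat              # (C_bar, N) values

    s = f.T @ g                   # s[i, j] = f(x_i)^T g(x_j), shape (N, N)
    beta = softmax(s, axis=0)     # beta[i, j] = beta_{j,i}, normalized over i for each output j

    o = W_v @ (h @ beta)          # o_j = W_v sum_i beta_{j,i} h(x_i), shape (C, N)
    y = gamma * o + x_flat        # residual connection scaled by gamma
    return y.reshape(C, H, W), beta


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    C, H, W = 16, 4, 4
    x = rng.standard_normal((C, H, W))
    W_f, W_g, W_h = (rng.standard_normal((C // 8, C)) for _ in range(3))
    W_v = rng.standard_normal((C, C // 8))
    y, beta = self_attention(x, W_f, W_g, W_h, W_v, gamma=0.0)
    print(np.allclose(y, x), beta.shape)  # True (16, 16)
```

With gamma = 0 the layer is the identity, which is why the paper initializes it to 0: the network first relies on local convolutional features and gradually learns to weight non-local evidence. In PyTorch the four projections would typically be nn.Conv2d layers with kernel_size=1.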

Requirements

Install the following requirements with pip install -r requirements.txt:

  • Python 3.7
  • numpy 1.17.2
  • torch 1.7.1
  • torchvision 0.8.2
  • Pillow 6.1.0

The following libraries are optional:

  • tensorboard
  • tqdm

Usage

Clone

git clone https://github.com/hiyouga/SAGAN-PyTorch.git
cd SAGAN-PyTorch

Train

python main.py --batch_size 64 --im_size 32 --dataset cifar10 --adv_loss wgan-gp

Results

The wgan-gp loss is used by default because it performed much better than the hinge loss in our experiments.
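
For reference, the two adversarial objectives compared here are sketched below in their standard unconditional forms: WGAN-GP as defined by Gulrajani et al. (2017) and the hinge loss as used in the SAGAN paper. D is the discriminator (critic), G the generator, x a real image, z a noise vector, and lambda the gradient-penalty weight (10 in Gulrajani et al.). The exact weights and any additional terms used in main.py are not stated in this README and may differ.

```latex
\begin{aligned}
\hat{x} &= \epsilon\, x + (1-\epsilon)\, G(z), \qquad \epsilon \sim \mathcal{U}[0,1] \\
L_D^{\text{wgan-gp}} &= \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x}\big[D(x)\big]
  + \lambda\, \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big] \\
L_G^{\text{wgan-gp}} &= -\mathbb{E}_{z}\big[D(G(z))\big] \\[4pt]
L_D^{\text{hinge}} &= \mathbb{E}_{x}\big[\max(0,\, 1 - D(x))\big] + \mathbb{E}_{z}\big[\max(0,\, 1 + D(G(z)))\big] \\
L_G^{\text{hinge}} &= -\mathbb{E}_{z}\big[D(G(z))\big]
\end{aligned}
```

In these forms both generator losses are identical, so the two settings differ only in how the discriminator is trained.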

We train on CIFAR-10 as an unsupervised training set, i.e. without using the class labels.

Real images

Generated images with wgan-gp loss

Generated images with hinge loss

Training details (visualized via TensorBoard)

Attention weight and Inception score with wgan-gp loss

Attention weight and Inception score with hinge loss
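
The Inception score plotted above is, following Salimans et al., IS = exp(E_x KL(p(y|x) || p(y))), where p(y|x) is the class distribution a pretrained Inception network assigns to a generated image x and p(y) is that distribution averaged over the generated images. The NumPy sketch below computes the score from an already-computed (N, K) matrix of softmax outputs and averages it over chunks, a common convention in the reference implementations listed under References. Running the Inception network itself is omitted, and the helper name inception_score is hypothetical, not this repository's API.

```python
import numpy as np


def inception_score(probs, splits=10, eps=1e-12):
    """Inception score from precomputed class probabilities (hypothetical helper).

    probs:  (N, K) array; row n is p(y | x_n), the softmax output of a
            pretrained Inception network for generated image x_n.
    splits: number of chunks; the score is computed per chunk and the
            mean and standard deviation across chunks are returned.
    """
    probs = np.asarray(probs, dtype=np.float64)
    scores = []
    for part in np.array_split(probs, splits):
        p_y = part.mean(axis=0, keepdims=True)                # marginal p(y) over this chunk
        kl = part * (np.log(part + eps) - np.log(p_y + eps))  # elementwise KL(p(y|x) || p(y)) terms
        scores.append(np.exp(kl.sum(axis=1).mean()))          # exp of the mean KL over images
    return float(np.mean(scores)), float(np.std(scores))


if __name__ == "__main__":
    K = 10
    uniform = np.full((100, K), 1.0 / K)        # classifier is never confident
    confident = np.eye(K)[np.arange(100) % K]   # confident and evenly spread over classes
    print(inception_score(uniform))
    print(inception_score(confident))
```

The two toy inputs bracket the score: uniform predictions give a score of about 1, while confident predictions spread evenly over K classes give about K, the maximum.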

Loss curves with wgan-gp loss

Loss curves with hinge loss

References

For SAGAN architecture:

  1. Zhang et al. Self-Attention Generative Adversarial Networks. ICML. 2019.
  2. https://github.com/heykeetae/Self-Attention-GAN
  3. https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

For inception score:

  1. Salimans et al. Improved Techniques for Training GANs. NeurIPS. 2016.
  2. Shane Barratt and Rishi Sharma. A Note on the Inception Score. ICML Workshop on Theoretical Foundations and Applications of Deep Generative Models. 2018.
  3. https://github.com/sbarratt/inception-score-pytorch
  4. https://github.com/w86763777/pytorch-gan-metrics

License

MIT
