markdtw / least-squares-gan

A GAN that adopts the least squares loss function for the discriminator, implemented in TensorFlow

Least Squares Generative Adversarial Networks

TensorFlow implementation of Least Squares Generative Adversarial Networks (LSGAN) by Mao et al.
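For readers new to LSGAN, here is a minimal NumPy sketch of the least squares objectives described in the paper. The function names and the target labels a=0, b=1, c=1 are assumptions made for illustration (this is one of the label codings the paper discusses); the code in this repo may differ.

```python
import numpy as np

def lsgan_d_loss(d_real, d_fake, a=0.0, b=1.0):
    """Discriminator loss: push D(x) toward b and D(G(z)) toward a."""
    real_term = 0.5 * np.mean((d_real - b) ** 2)
    fake_term = 0.5 * np.mean((d_fake - a) ** 2)
    return real_term + fake_term

def lsgan_g_loss(d_fake, c=1.0):
    """Generator loss: push D(G(z)) toward c, the value G wants D to output."""
    return 0.5 * np.mean((d_fake - c) ** 2)

# Toy discriminator outputs (raw scores, no sigmoid) for a batch of two.
d_real = np.array([0.9, 0.8])
d_fake = np.array([0.1, 0.3])
print("D loss:", lsgan_d_loss(d_real, d_fake))
print("G loss:", lsgan_g_loss(d_fake))
```

Unlike the sigmoid cross-entropy loss of a regular GAN, these losses penalize samples by their squared distance from the target label, so samples that are already on the correct side of the decision boundary still contribute gradient.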

Prerequisites

Data

Preparation

  1. Clone this repo, create ckpt/ folder:

    git clone https://github.com/markdtw/least-squares-gan.git
    cd least-squares-gan
    mkdir ckpt
  2. To train on LSUN, use the provided tools to download and extract. For example:

    python download.py -c conference_room
    unzip conference_room_train_lmdb.zip
    python data.py export conference_room_train_lmdb --out_dir conference_room_train_images --flat

    I changed .webp to .jpg in this line.

  3. To train on CelebA, I use this file to download the dataset. Shout-out to carpedm20.

  4. Now you are good to go. The first time you train on LSUN, all the images will be center-cropped to 224x224 and stored in a new folder.
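To make the center-crop step concrete, here is a minimal NumPy sketch. The helper name `center_crop` is hypothetical, and the sketch assumes the image is already at least 224 pixels on each side; the preprocessing code in this repo may resize or handle edge cases differently.

```python
import numpy as np

def center_crop(image, size=224):
    """Return the centered size x size window of an H x W (x C) image array."""
    height, width = image.shape[:2]
    if height < size or width < size:
        raise ValueError("image is smaller than the requested crop size")
    top = (height - size) // 2
    left = (width - size) // 2
    return image[top:top + size, left:left + size]

# Dummy 256x300 RGB image standing in for an exported LSUN picture.
dummy = np.zeros((256, 300, 3), dtype=np.uint8)
print(center_crop(dummy).shape)
```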

Train

Train on LSUN conference room with default settings:

python main.py --train

Train on CelebA with default settings:

python main.py --train --dataset=CelebA

Train from a previous checkpoint at epoch X:

python main.py --train --modelpath=ckpt/lsgan-LSUN<CelebA>-X

Check out tunable hyper-parameters:

python main.py
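As a rough picture of how the flags used above fit together, here is a hypothetical argparse sketch. Only the flag names `--train`, `--dataset`, and `--modelpath` come from the commands in this README; the defaults, help strings, and the `build_parser` helper are assumptions, and this is not the actual content of main.py.

```python
import argparse

def build_parser():
    """Hypothetical parser mirroring the flags shown in the commands above."""
    parser = argparse.ArgumentParser(description="LSGAN training (sketch)")
    parser.add_argument("--train", action="store_true",
                        help="run training")
    parser.add_argument("--dataset", default="LSUN",
                        help="dataset to train on: LSUN or CelebA (assumed default: LSUN)")
    parser.add_argument("--modelpath", default=None,
                        help="checkpoint to resume from, e.g. ckpt/lsgan-LSUN-X")
    return parser

# Equivalent to: python main.py --train --dataset=CelebA
args = build_parser().parse_args(["--train", "--dataset=CelebA"])
print(args.train, args.dataset, args.modelpath)
```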

Some results

Epoch 10: ep-10

Epoch 25: ep-25

Epoch 45: ep-45

Results from epoch 45 are already nice and crisp.

Generator loss: g-loss

Discriminator loss: d-loss

Notes

  • The model will save 40 generated pictures in log/ folder every epoch.
  • Initialization is important! The default initialization with tf.xavier_initializer causes the gradients of either D or G to vanish. Instead, I use tf.truncated_normal_initializer, as in the original DCGAN implementation, which solves the problem.
  • Issues are more than welcome!
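To illustrate the initialization note, here is a NumPy sketch of what a truncated normal initializer does: it samples from a normal distribution and redraws any value more than two standard deviations from the mean. The `truncated_normal` helper is hypothetical, and the value stddev=0.02 is an assumption based on the DCGAN convention; check the code for the value this repo actually uses.

```python
import numpy as np

def truncated_normal(shape, mean=0.0, stddev=0.02, seed=None):
    """Sample normal values, redrawing any that fall beyond 2 * stddev of the mean."""
    rng = np.random.default_rng(seed)
    values = rng.normal(mean, stddev, size=shape)
    out_of_range = np.abs(values - mean) > 2 * stddev
    while out_of_range.any():
        # Redraw only the offending entries until all lie within the bound.
        values[out_of_range] = rng.normal(mean, stddev, size=int(out_of_range.sum()))
        out_of_range = np.abs(values - mean) > 2 * stddev
    return values

# Example: weights for a 5x5 convolution with 3 input and 64 output channels.
weights = truncated_normal((5, 5, 3, 64), seed=0)
print(weights.shape, float(np.abs(weights).max()))
```

With a small, bounded spread like this, every weight starts close to zero, whereas a Xavier-style scheme scales the spread by the layer's fan-in and fan-out.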

Resources
