kpandey008 / wasserstein-gans

Implementation of Wasserstein Generative Adversarial Networks using TensorFlow



Wasserstein Generative Adversarial Networks (WGANs)

Implementation of the popular paper Wasserstein GAN in TensorFlow

What is a WGAN?

The Wasserstein Generative Adversarial Network, or WGAN, is a recently introduced variant of the popular Generative Adversarial Network (GAN). It is commonly known in the machine learning community that GANs are notoriously difficult to train (don't believe me? try training one!) and suffer from the following issues:

  1. Training the Generator and the Discriminator of a GAN is difficult. This is due to the nature of the adversarial loss itself, which induces a minimax game between the Generator and Discriminator networks: each tries to outplay the other until they reach a Nash equilibrium (at least in theory!), a state in which the Generator produces samples so realistic that the Discriminator cannot distinguish them from the original data. In practice this is very difficult to achieve with gradient-descent-based training and often results in model oscillation. (The minimax objective behind this game is written out after this list.)

  2. GAN training depends heavily on the architecture of the Generator and the Discriminator. It is seen that DCGANs (https://github.com/kpandey008/DCGANS), which are the convolutional variants of GANs, tend to perform better than traditional MLP-based GANs on several metrics (e.g. visual quality of samples, and better classification accuracy on some intrinsic tasks using the features learned by the Discriminator).

  3. GANs are highly susceptible to mode collapse, a scenario in which the GAN latches onto a single mode of the data distribution and concentrates its generated samples around that mode (e.g. generating samples of only one digit out of many when trained on the MNIST dataset).

  4. GANs might learn to draw samples that are visually pleasing but lie outside the data manifold. (Seems like a paradox, but it can happen!)
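
For reference, the minimax objective from item 1 above (for the standard GAN, before any Wasserstein modification) can be written as:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

where G is the Generator, D is the Discriminator, p_data is the data distribution, and p_z is the noise prior.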

The WGAN paper addresses some of these problems by using a distance measure called the Wasserstein distance, also known as the Earth Mover's Distance. Specifically, the paper addresses the following problems:

  1. The Wasserstein distance is shown to be a better-behaved metric than alternatives such as the Jensen-Shannon divergence or the Total Variation distance. Using the Wasserstein loss stabilizes GAN training, and the authors claim that it removes the dependence of GAN training on network architectural constraints. (A minimal sketch of the resulting losses is given after this list.)

  2. The Wasserstein distance provides a stable loss metric that can be used to assess the performance of the GAN. Until this paper, there was no reliable measure of GAN training progress apart from manually inspecting the quality of samples generated by the Generator.
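
To make the idea concrete, below is a minimal sketch of the WGAN critic and generator losses together with the weight clipping used in the original paper. It is written in TensorFlow 2.x style for readability rather than the notebook's TensorFlow 0.12 code, and the function names are illustrative only.

```python
import tensorflow as tf

def critic_loss(real_scores, fake_scores):
    # The critic tries to maximize E[D(real)] - E[D(fake)],
    # so we minimize the negated quantity.
    return tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)

def generator_loss(fake_scores):
    # The generator tries to push the critic's scores on fake samples up.
    return -tf.reduce_mean(fake_scores)

def clip_critic_weights(critic_variables, clip_value=0.01):
    # Weight clipping, as in the original paper, to (crudely) keep the
    # critic within a family of Lipschitz functions.
    for v in critic_variables:
        v.assign(tf.clip_by_value(v, -clip_value, clip_value))
```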

Diving into the code

Prerequisites

The following python packages must be installed for running the code

  • Python 2.7 or Python 3.3+
  • Tensorflow 0.12.1
  • Numpy
  • Matplotlib
  • ImageIO
  • Scikit-learn

I prefer to use Google Colaboratory for training such systems because of the heavy computational requirements. Here is the link to an excellent Medium post on setting up Google Colab with Drive to manage your ML projects: https://medium.com/deep-learning-turkey/google-colab-free-gpu-tutorial-e113627b9f5d

Running the code

You can find the code for the WGAN in the Jupyter notebook WGAN.ipynb. Running it should be fairly simple, as the code is well commented and explained fairly thoroughly. After running the code you should see the following directory structure in your current working directory:

|--model_data_dcgan
     |--experiment_
         |--tmp/
         |--checkpoints/
         |--gifs/
         |--config.yaml

The tmp/ folder contains the images generated by the Generator at every 100th training step from a fixed noise vector (refer to the code for this part).
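
For illustration, here is a minimal sketch of how such fixed-noise snapshots can be produced. The helper name `save_snapshot`, the `generate_fn` argument, and the 4x4 grid are assumptions for the sketch, not the notebook's exact code.

```python
import os
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def save_snapshot(generate_fn, fixed_noise, step, out_dir="tmp"):
    # `generate_fn` is assumed to map a batch of noise vectors to images
    # in [0, 1]; reusing the same `fixed_noise` at every call makes the
    # snapshots comparable across training steps.
    images = generate_fn(fixed_noise)
    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for ax, img in zip(axes.flat, images):
        ax.imshow(np.squeeze(img), cmap="gray")
        ax.axis("off")
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, "step_%06d.png" % step))
    plt.close(fig)
```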

The checkpoints/ folder basically checkpoints your training (:P) so that you can resume in case the training ends abruptly.
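
A sketch of how such resumable checkpointing might look in modern TensorFlow 2.x style (the notebook itself targets TensorFlow 0.12, where tf.train.Saver plays the same role); `build_checkpointing` is a hypothetical helper, not the repository's code:

```python
import tensorflow as tf

def build_checkpointing(generator, critic, ckpt_dir="checkpoints"):
    # Track both networks plus a step counter so training can resume
    # from the latest checkpoint if it ends abruptly.
    step = tf.Variable(0, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(generator=generator, critic=critic, step=step)
    manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)  # no-op when nothing has been saved yet
    return ckpt, manager

# Inside the training loop one would then periodically call manager.save().
```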

The gifs/ folder combines the images in the tmp/ folder into a visualization of the Generator's outputs over the course of training.
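
Such a GIF can be stitched together with imageio (already listed in the prerequisites); this is a rough sketch of the idea, not the notebook's exact code:

```python
import glob
import os
import imageio

def make_training_gif(tmp_dir="tmp", out_path="gifs/training.gif", frame_duration=0.1):
    # Collect the periodic snapshots in order and write them out as a GIF.
    frame_paths = sorted(glob.glob(os.path.join(tmp_dir, "*.png")))
    frames = [imageio.imread(p) for p in frame_paths]
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    imageio.mimsave(out_path, frames, duration=frame_duration)
```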

The file config.yaml stores the configuration of the Generator and Discriminator networks used for that particular experiment.
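
Purely as an illustration, such a config can be read back with PyYAML; the keys shown here are hypothetical and may differ from what the notebook actually writes:

```python
import yaml

# Load the experiment configuration saved alongside the checkpoints.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# Hypothetical keys, used only to illustrate how the config might be consumed.
print(config.get("learning_rate"), config.get("batch_size"))
```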

Results

Here is the visualization of network learning over 10000 training steps

The MNIST samples generated by the network are as shown below:

These outputs were obtained by training the GAN with a convolutional Generator and Discriminator using the Wasserstein loss over 50000 training steps. You can find more information about the network hyperparameters used in the Jupyter notebook.

Note: As we can observe, the quality of the generated samples is comparable to that of DCGANs (apart from a few samples), but training is much more stable than with the latter. However, I feel that the sample quality could be improved by training a network with larger capacity for more steps. Also, the presence of batchnorm in both the Generator and the Discriminator helps stabilize training to a great extent.

Conclusions

  1. The Wasserstein GAN indeed provides stable gradients during training, provided the Lipschitz constraint is enforced on the Discriminator network (this is probably the key takeaway from the paper). For me, the training sometimes suffered from the exploding-gradient problem.

  2. I tested WGANs with multiple network architectures, and neither the performance of the algorithm nor the visual quality of the samples was affected much. Thus the paper's claim of architectural robustness holds up.

  3. However, model training still showed oscillations for me. Maybe a bit of hyperparameter tuning is needed (although the paper claims that training is quite robust to hyperparameter choices).

  4. Weight clipping is probably not the best way to enforce the Lipschitz constraint on the Discriminator network. In this case, however, it performs fairly well except when gradient explosion takes place. (A sketch of the gradient-penalty alternative is given after this list.)

  5. WGANs do not suffer from mode collapse, which can be observed in the above results. This is because the discriminator (critic) does not saturate during training, in contrast to standard GANs, which can suffer from considerable mode collapse.
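
For reference, the gradient-penalty alternative to weight clipping (from the Improved Training of Wasserstein GANs paper listed under Further Reading) replaces clipping with a penalty on the critic's gradient norm. Below is a minimal sketch in TensorFlow 2.x style; it is not part of this repository's notebook, and the function name and arguments are illustrative only.

```python
import tensorflow as tf

def gradient_penalty(critic, real_images, fake_images, gp_weight=10.0):
    # Penalize the critic when the norm of its gradient at points
    # interpolated between real and fake samples deviates from 1.
    batch_size = tf.shape(real_images)[0]
    eps = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = eps * real_images + (1.0 - eps) * fake_images
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = critic(interpolated)
    grads = tape.gradient(scores, interpolated)
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return gp_weight * tf.reduce_mean(tf.square(norms - 1.0))
```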

Further Reading

Some good resources for learning more about WGANs and improving their training are:

  1. The original WGAN paper: https://arxiv.org/pdf/1701.07875.pdf
  2. The repo https://github.com/soumith/ganhacks provides a number of hacks for training GANs in general in a stable setting.
  3. The paper Improved Training of Wasserstein GANs by Gulrajani et al. is an excellent resource for learning more about WGAN training. The paper can be found at https://arxiv.org/pdf/1704.00028.pdf

Author

Kushagra Pandey / @kpandey008


License: MIT License

