lweitkamp / GANs-JAX

Implementation of several Generative Adversarial Networks in JAX / Flax

Generative Adversarial Networks in JAX

This repository holds several notebooks that implement GANs in JAX using the Flax Linen package. All models are trained on the MNIST dataset using TPUs on Colab, with parallelization enabled by default.
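
On a multi-core Colab TPU, parallelization is typically data parallelism across the cores. Below is a minimal sketch of one common way to do this in JAX: `jax.pmap` with gradient averaging via `jax.lax.pmean`. The notebooks may organize this differently. The toy linear model, `loss_fn`, and `train_step` are hypothetical stand-ins for the GAN networks and their update steps. The sketch also runs on a CPU with a single device.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical toy model standing in for a GAN network: linear regression.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    # Each device computes the loss and gradients on its own shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients and loss over all devices (e.g. the TPU cores).
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    # Plain SGD step; every device applies the same averaged update.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


n_dev = jax.local_device_count()
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
# Replicate the parameters by adding a leading device axis.
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

# Shard the data: axis 0 is the device, axis 1 the per-device batch.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (n_dev, 16, 3))
y = x @ jnp.array([[1.0], [2.0], [3.0]])

for _ in range(300):
    replicated, loss = train_step(replicated, x, y)
```

Because every device applies the same averaged gradient, the parameter replicas stay identical. Only the data differs between devices.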

Deep Convolutional GAN

The original GAN, using the architecture and other training tips from the DCGAN paper, "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks".
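
The adversarial objective that DCGAN builds on can be written as two loss functions on the discriminator's outputs. The sketch below is in plain JAX and is not code from the notebook. It assumes the discriminator returns raw logits. For the generator, it uses the non-saturating loss suggested in the original GAN paper.

```python
import jax
import jax.numpy as jnp


def discriminator_loss(real_logits, fake_logits):
    # -[log D(x) + log(1 - D(G(z)))], expressed with logits for numerical stability:
    # log D = log_sigmoid(logit), log(1 - D) = log_sigmoid(-logit).
    return -jnp.mean(jax.nn.log_sigmoid(real_logits) + jax.nn.log_sigmoid(-fake_logits))


def generator_loss(fake_logits):
    # Non-saturating generator loss: maximize log D(G(z))
    # instead of minimizing log(1 - D(G(z))).
    return -jnp.mean(jax.nn.log_sigmoid(fake_logits))
```

Suppose the discriminator is maximally uncertain, with all logits 0. Then `discriminator_loss` equals 2 log 2 and `generator_loss` equals log 2.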

Wasserstein GAN with Gradient Penalty

Training GANs is notoriously difficult. Even with a carefully chosen architecture, training can still suffer from mode collapse. The authors of the Wasserstein GAN paper argue that the main problem is the divergence the vanilla GAN implicitly minimizes to match the data distribution. Minimizing the earth mover (Wasserstein-1) distance instead alleviates this problem, because it still provides useful gradients when the real and generated distributions barely overlap.
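
The sketch below assumes the penalty is the gradient penalty from WGAN-GP (Gulrajani et al.). That method replaces the original WGAN's weight clipping as the way to keep the critic approximately 1-Lipschitz. This is not code from the notebook. The linear `critic` is a hypothetical stand-in for the convolutional critic, and the penalty weight of 10 is the value commonly used with WGAN-GP.

```python
import jax
import jax.numpy as jnp


def critic(params, x):
    # Hypothetical toy critic: a linear function of the flattened input, shape (batch, 1).
    return x.reshape(x.shape[0], -1) @ params["w"]


def gradient_penalty(params, real, fake, key, weight=10.0):
    batch = real.shape[0]
    # One interpolation coefficient per example, broadcast over the image dimensions.
    eps = jax.random.uniform(key, (batch,) + (1,) * (real.ndim - 1))
    x_hat = eps * real + (1.0 - eps) * fake

    # Scalar critic output for a single example, so jax.grad applies.
    def critic_single(x):
        return critic(params, x[None])[0, 0]

    # Gradient of the critic w.r.t. each interpolated input, vectorized over the batch.
    grads = jax.vmap(jax.grad(critic_single))(x_hat)
    norms = jnp.sqrt(jnp.sum(grads.reshape(batch, -1) ** 2, axis=1) + 1e-12)
    # Penalize deviations of the gradient norm from 1.
    return weight * jnp.mean((norms - 1.0) ** 2)


def critic_loss(params, real, fake, key):
    # The critic maximizes E[D(real)] - E[D(fake)], so we minimize the negation,
    # plus the gradient penalty.
    wasserstein = jnp.mean(critic(params, fake)) - jnp.mean(critic(params, real))
    return wasserstein + gradient_penalty(params, real, fake, key)
```

The penalty is evaluated at points sampled on straight lines between real and generated examples. For the linear toy critic, the input gradient is simply `w`, so the penalty vanishes when `w` has unit norm.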

Conditional GAN

This is the logical next step after the vanilla GAN: if labels are available, we should make use of them. The Conditional GAN, as the name implies, conditions the generator's output on the labels in addition to the noise. The discriminator in turn receives both an image (real or generated) and its label, and classifies the pair as real or fake.
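
A common way to condition both networks is assumed here; it is not taken from the notebook. The generator's noise input gets a one-hot label vector concatenated to it. The discriminator's image input gets the label appended as constant feature maps. The helper names and the 100-dimensional noise are hypothetical.

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 10  # MNIST digits 0-9


def generator_input(key, labels, latent_dim=100):
    # Sample noise z and append the one-hot label y: the input to G(z, y).
    z = jax.random.normal(key, (labels.shape[0], latent_dim))
    y = jax.nn.one_hot(labels, NUM_CLASSES)
    return jnp.concatenate([z, y], axis=-1)


def discriminator_input(images, labels):
    # Broadcast the one-hot label to per-pixel channels and stack it onto the
    # NHWC image: the input to D(x, y).
    batch, h, w, _ = images.shape
    y = jax.nn.one_hot(labels, NUM_CLASSES)[:, None, None, :]
    y_maps = jnp.broadcast_to(y, (batch, h, w, NUM_CLASSES))
    return jnp.concatenate([images, y_maps], axis=-1)
```

For a batch of 28x28x1 MNIST images, the generator input has 110 features and the discriminator input has 11 channels.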

InfoGAN

My personal favorite is the information-maximizing GAN (InfoGAN). As the authors mention, the mutual-information loss converges faster than the GAN loss, so this addition comes essentially for free. The result is a somewhat disentangled latent space in which the digits are easily separable. A great reference and interpretation of both the InfoGAN objective and the vanilla GAN objective can be found in Ferenc Huszár's blog.
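
For a categorical latent code, the information term reduces to a cross-entropy. An auxiliary network Q predicts the code that was fed to the generator; in the paper, Q shares its convolutional layers with the discriminator. Minimizing this cross-entropy maximizes a variational lower bound on the mutual information. The sketch below assumes a single 10-way code and a Q head that outputs logits. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

NUM_CATEGORIES = 10  # one categorical code, e.g. matching the 10 MNIST digits


def sample_code(key, batch):
    # c ~ Cat(K=10, p=0.1): a uniform categorical code, fed to G as a one-hot vector.
    c = jax.random.randint(key, (batch,), 0, NUM_CATEGORIES)
    return c, jax.nn.one_hot(c, NUM_CATEGORIES)


def info_loss(q_logits, c):
    # Lower bound: L_I = E[log Q(c | x)] + H(c). H(c) is constant for a fixed prior,
    # so minimizing -E[log Q(c | x)] (a cross-entropy) maximizes L_I.
    log_q = jax.nn.log_softmax(q_logits, axis=-1)
    return -jnp.mean(jnp.take_along_axis(log_q, c[:, None], axis=-1))
```

In the paper, this term is added to both the generator and Q objectives, with weight λ = 1 for discrete codes. If Q outputs uniform logits, the loss equals log 10, the entropy of the code.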
