xinario / catgan_pytorch

Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks


catGAN

A PyTorch implementation of Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks (catGAN), the method originally proposed by Jost Tobias Springenberg.

Results on CIFAR10

Note that only the unsupervised version is implemented in this repo for now. I replaced the original architecture with DCGAN, and the generated samples are more colorful than those from the original implementation.
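For readers new to catGAN, the unsupervised objective from Springenberg's paper can be summarized like this: the discriminator is a K-way classifier that should be confident (low conditional entropy) on real images, uncertain (high conditional entropy) on generated images, and should spread real images evenly over all K classes (high marginal entropy). The generator tries to produce images that the discriminator classifies confidently and that are spread evenly over the classes. The NumPy sketch below computes both losses from raw discriminator logits. It illustrates the paper's objective and is not code taken from this repo; the function names are hypothetical.

```python
import numpy as np

EPS = 1e-8  # guards against log(0)


def entropy(p, axis=-1):
    """Shannon entropy (in nats) of probability vectors along `axis`."""
    return -np.sum(p * np.log(p + EPS), axis=axis)


def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)


def catgan_losses(logits_real, logits_fake):
    """Return (loss_D, loss_G) for one batch; both are meant to be minimized.

    logits_real: (N, K) discriminator outputs on real images
    logits_fake: (M, K) discriminator outputs on generated images
    """
    p_real = softmax(logits_real)
    p_fake = softmax(logits_fake)

    # Conditional entropies H[p(y|x, D)], averaged over the batch.
    cond_real = entropy(p_real).mean()  # D should be confident on real data
    cond_fake = entropy(p_fake).mean()  # D should be uncertain on fake data

    # Entropies of the batch-averaged (marginal) class distributions.
    marg_real = entropy(p_real.mean(axis=0))  # real data spread over all K classes
    marg_fake = entropy(p_fake.mean(axis=0))  # fake data spread over all K classes

    # Discriminator: min  E[H(real)] - E[H(fake)] - H(marginal real)
    loss_D = cond_real - cond_fake - marg_real
    # Generator:     min  E[H(fake)] - H(marginal fake)
    loss_G = cond_fake - marg_fake
    return loss_D, loss_G
```

With K set to the number of CIFAR10 classes (10), `catgan_losses(logits_real, logits_fake)` returns two scalars. In an actual PyTorch training loop the same expressions would be written with tensor operations so that gradients flow back into the discriminator and the generator.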

Generated samples from epoch 0 to epoch 100:

[Figure: cifar10]

Prerequisites

  • Python 2.7
  • PyTorch v0.2.0
  • Numpy
  • SciPy
  • Matplotlib

Getting Started

Installation

  • Install PyTorch and the other dependencies
  • Clone this repo:
git clone https://github.com/xinario/catgan_pytorch.git
cd catgan_pytorch

Train

  • Download the CIFAR10 dataset in .png format (from Kaggle)
  • Create a dataset folder to hold the images
mkdir -p ./datasets/cifar10/images
  • Move the extracted images into the newly created folder

  • Train a model:

python catgan_cifar10.py --data_dir ./datasets/cifar10 --name cifar10
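The dataset steps above can be scripted. The sketch below assumes the Kaggle archive has already been extracted to some local folder (its location is up to you and is hypothetical here) and copies every .png it finds into ./datasets/cifar10/images. The extra images subfolder is presumably there because torchvision's ImageFolder, a common way to load such data, expects images to sit in at least one subdirectory of the root passed as --data_dir; this is an assumption about the training script. The helper name collect_pngs is hypothetical, and the code avoids `exist_ok` so it also runs on Python 2.7.

```python
import os
import shutil


def collect_pngs(src_dir, dst_dir="./datasets/cifar10/images"):
    """Copy every .png found (recursively) under src_dir into dst_dir.

    src_dir is wherever the Kaggle archive was extracted (hypothetical path).
    Files with the same name in different subfolders would overwrite each
    other; the Kaggle CIFAR10 training images are assumed to have unique names.
    Returns the number of files copied.
    """
    if not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)  # same effect as `mkdir -p`
    count = 0
    for root, _, files in os.walk(src_dir):
        for name in files:
            if name.lower().endswith(".png"):
                shutil.copy(os.path.join(root, name), os.path.join(dst_dir, name))
                count += 1
    return count
```

For example, `collect_pngs("./train")` followed by the training command above, if the archive was extracted to ./train.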

All generated plots and samples are saved inside ./results/cifar10.

Training options

optional arguments:

--continue_train      continue training from the latest checkpoints if --netG and --netD are not specified
--netG NETG           path to netG (to continue training)
--netD NETD           path to netD (to continue training)
--workers WORKERS     number of data loading workers
--num_epochs EPOCHS   number of epochs to train for
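A minimal argparse sketch of the options listed above, including one plausible way --continue_train could fall back to the newest checkpoints when --netG and --netD are omitted. The default values, the checkpoint directory, and the netG_epoch_<N>.pth naming scheme are hypothetical; check the training script for the real ones.

```python
import argparse
import os
import re


def build_parser():
    # Flag names mirror the README; defaults and help strings are assumptions.
    p = argparse.ArgumentParser(description="catGAN training (sketch)")
    p.add_argument("--data_dir", required=True)
    p.add_argument("--name", default="cifar10")
    p.add_argument("--continue_train", action="store_true",
                   help="resume from the latest checkpoints if --netG/--netD are not given")
    p.add_argument("--netG", default="", help="path to netG (to continue training)")
    p.add_argument("--netD", default="", help="path to netD (to continue training)")
    p.add_argument("--workers", type=int, default=2, help="number of data loading workers")
    p.add_argument("--num_epochs", type=int, default=100, metavar="EPOCHS",
                   help="number of epochs to train for")
    return p


def latest_checkpoint(ckpt_dir, prefix):
    """Return the path of `<prefix>_epoch_<N>.pth` with the largest N, or None.

    The file-naming scheme is hypothetical; the real script may differ.
    """
    if not os.path.isdir(ckpt_dir):
        return None
    pattern = re.compile(r"^%s_epoch_(\d+)\.pth$" % re.escape(prefix))
    best, best_epoch = None, -1
    for name in os.listdir(ckpt_dir):
        m = pattern.match(name)
        if m and int(m.group(1)) > best_epoch:
            best, best_epoch = os.path.join(ckpt_dir, name), int(m.group(1))
    return best


def resolve_checkpoints(args, ckpt_dir):
    """Explicit --netG/--netD paths win; otherwise fall back to the latest files."""
    netG, netD = args.netG, args.netD
    if args.continue_train:
        netG = netG or latest_checkpoint(ckpt_dir, "netG")
        netD = netD or latest_checkpoint(ckpt_dir, "netD")
    return netG, netD
```

argparse derives the metavars NETG, NETD and WORKERS from the option names, which matches the help text above; only EPOCHS has to be set explicitly because the option is called --num_epochs.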

More options can be found inside the training script.

Acknowledgments

Some of the code is inspired by or borrowed from the wgan-gp, DCGAN, and catGAN Chainer repos.
