TracyCuiq / wgan_caffe

caffe_wgan

Caffe implementation of WGAN ("Wasserstein GAN")

A few notes

  • This implementation adds the extra critic training found in the original code at lines 163-166 of main.py. Those lines take effect only during the first 25 generator iterations and, after that, only sporadically (once every 500 generator iterations). In those cases the number of critic iterations is set to 100 instead of the default 5, which lets the critic start near its optimum even in the first iterations. This should not make a major difference in performance, but it can help, especially when visualizing learning curves (otherwise the loss would appear to rise until the critic is properly trained). It is also why the first 25 iterations take significantly longer than the rest of training.
  • As described in the original PyTorch implementation, a problem arises when the critic fails to stay close to its optimum, so its error stops being a good Wasserstein estimate. Known causes are high learning rates and momentum, and anything that helps the critic get back on track is likely to help with the issue. We use several learning rates (0.00005, 0.00002, 0.00001), and we can verify this problem is
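
As a rough illustration of the first note, the critic schedule can be sketched as a small shell function. The name critic_iters and the zero-based generator iteration counter are assumptions made for this sketch; the actual logic lives in this repository's sources and in main.py of the original code.

```sh
#!/bin/sh
# Sketch of the critic-iteration schedule described in the first note.
# critic_iters is a hypothetical helper, not part of wgan_release.
# $1: generator iterations completed so far (assumed to start at 0)
# $2: default critic iterations per generator iteration (5 if omitted)
critic_iters() {
    gen_iter=$1
    default_iters=${2:-5}
    # Boost the critic during the first 25 generator iterations
    # and again once every 500 generator iterations.
    if [ "$gen_iter" -lt 25 ] || [ $((gen_iter % 500)) -eq 0 ]; then
        echo 100
    else
        echo "$default_iters"
    fi
}

# Print the schedule around the interesting boundaries.
for i in 0 24 25 499 500 501; do
    echo "generator iteration $i: $(critic_iters "$i") critic iterations"
done
```

Under these assumptions, iterations 0 to 24 and every multiple of 500 run 100 critic iterations, and all others run the default 5.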

Prerequisites

  • A computer running Linux. We used Debian GNU/Linux 9.0.
  • A specially adapted version of the Caffe library. This version includes two new functions used in the WGAN implementation.
  • For training, an NVIDIA GPU with the cuDNN library is strongly recommended for speed. CPU is supported, but training is very slow. We used CUDA v8.0 and cuDNN v7.0.
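
Before building, it can be useful to confirm that the GPU tooling is visible on the PATH. The following is a minimal sketch; check_tool is a hypothetical helper, and the check only looks for the nvidia-smi and nvcc executables. It does not detect cuDNN or the adapted Caffe build.

```sh
#!/bin/sh
# Sketch: report whether the NVIDIA driver utility and the CUDA compiler
# are available on the PATH. check_tool is a hypothetical helper.
check_tool() {
    if command -v "$1" >/dev/null 2>&1; then
        echo "$1: found"
    else
        echo "$1: missing"
    fi
}

# nvidia-smi ships with the NVIDIA driver, nvcc with the CUDA toolkit.
for tool in nvidia-smi nvcc; do
    check_tool "$tool"
done
```

If nvcc is reported as found, running nvcc --version shows the installed CUDA release, which can be compared against the version listed above.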

Reproducing LFW_FACES experiments

You need to download the face data in lw_faces.

You can find the Caffe models in the models folder.

First Step with learning rate 0.00005

./bin/wgan_release --run-wgan --log [log_file_name] --batch-size 64 \
--d-iters-by-g-iter 5 --main-iter 7800 --z-vector-bin-file [z_vector_file_name] --z-vector-size 100 \
--dataset LFW_faces --data-src-path ./bin/data/lfw_funneled --output-path [output_folder] \
--solver-d-model ./models/solver_d_lr_A.prototxt --solver-g-model ./models/solver_g_lr_A.prototxt

Second Step with learning rate 0.00002

./bin/wgan_release --run-wgan --log [log_file_name] --batch-size 64 \
--d-iters-by-g-iter 5 --main-iter 7800 --z-vector-bin-file [z_vector_file_name] --z-vector-size 100 \
--dataset LFW_faces --data-src-path ./bin/data/lfw_funneled --output-path [output_folder] \
--solver-d-model ./models/solver_d_lr_B.prototxt --solver-g-model ./models/solver_g_lr_B.prototxt \
--solver-d-state [output_folder]/wgan_d_iter_40000.solverstate --solver-g-state [output_folder]/wgan_g_iter_7500.solverstate

Third Step with learning rate 0.00001

./bin/wgan_release --run-wgan --log [log_file_name] --batch-size 64 \
--d-iters-by-g-iter 5 --main-iter 7800 --z-vector-bin-file [z_vector_file_name] --z-vector-size 100 \
--dataset LFW_faces --data-src-path ./bin/data/lfw_funneled --output-path [output_folder] \
--solver-d-model ./models/solver_d_lr_C.prototxt --solver-g-model ./models/solver_g_lr_C.prototxt \
--solver-d-state [output_folder]/wgan_d_iter_80000.solverstate --solver-g-state [output_folder]/wgan_g_iter_15500.solverstate
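
The three steps above differ only in the solver files and in the snapshots they resume from, so they can be wrapped in one script. The sketch below uses only the flags shown in the commands above. The variables OUT, LOG_PREFIX, ZFILE, DATASET, DATA_SRC and DRY_RUN, the helper names, and the per-step log file names are assumptions for illustration. By default it only prints the commands.

```sh
#!/bin/sh
# Sketch: print (or run) the three training steps in sequence.
# All variables below are hypothetical placeholders; replace them.
OUT=${OUT:-./output}
LOG_PREFIX=${LOG_PREFIX:-wgan}
ZFILE=${ZFILE:-z_vector.bin}
DATASET=${DATASET:-LFW_faces}
DATA_SRC=${DATA_SRC:-./bin/data/lfw_funneled}

# $1: solver suffix (A, B or C)
# $2, $3: critic and generator snapshot iterations to resume from
#         (empty for the first step, which starts from scratch)
run_step() {
    lr=$1; d_iter=$2; g_iter=$3
    set -- ./bin/wgan_release --run-wgan --log "${LOG_PREFIX}_${lr}.log" \
        --batch-size 64 --d-iters-by-g-iter 5 --main-iter 7800 \
        --z-vector-bin-file "$ZFILE" --z-vector-size 100 \
        --dataset "$DATASET" --data-src-path "$DATA_SRC" --output-path "$OUT" \
        --solver-d-model "./models/solver_d_lr_${lr}.prototxt" \
        --solver-g-model "./models/solver_g_lr_${lr}.prototxt"
    if [ -n "$d_iter" ]; then
        set -- "$@" \
            --solver-d-state "$OUT/wgan_d_iter_${d_iter}.solverstate" \
            --solver-g-state "$OUT/wgan_g_iter_${g_iter}.solverstate"
    fi
    if [ "${DRY_RUN:-1}" = "1" ]; then
        echo "$*"      # dry run: show the command only
    else
        "$@"           # real run: execute the command
    fi
}

# Snapshot iteration numbers are taken from the commands above.
run_all() {
    run_step A "" "" &&
    run_step B 40000 7500 &&
    run_step C 80000 15500
}

run_all
```

Setting DRY_RUN=0 runs the commands instead of printing them, and each step starts only if the previous one succeeded.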

[Figure: faces_generation]

You can see the training evolution:

[Animation]

loss Generator

[Figure: Loss Generator Evolution]

log(loss Discriminator)

[Figure: Loss Discriminator Evolution]

Reproducing CIFAR10 (airplanes) experiments

You need to download the CIFAR-10 data in its binary version.
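
In the CIFAR-10 binary version, each batch file is a sequence of 3073-byte records: one label byte followed by 3072 pixel bytes, with label 0 standing for the airplane class. The sketch below counts the records with a given label in one batch file, which is a quick way to check a download. The helper count_label and the CIFAR_BATCH variable are hypothetical, and how wgan_release itself selects the airplane images is not shown here.

```sh
#!/bin/sh
# Sketch: count records with a given label in a CIFAR-10 binary batch file.
# Record layout: 1 label byte + 3072 pixel bytes = 3073 bytes per image.
# count_label is a hypothetical helper, not part of this repository.
# $1: path to a batch file, $2: class label (0 = airplane)
count_label() {
    # Dump every byte as an unsigned decimal, one value per line, then
    # look only at the first byte of each 3073-byte record.
    od -An -v -tu1 "$1" | tr -s ' ' '\n' |
        awk -v l="$2" 'NF { if (i % 3073 == 0 && $1 == l) n++; i++ }
                       END { print n + 0 }'
}

# CIFAR_BATCH is a hypothetical variable holding the path to one batch file.
if [ -n "${CIFAR_BATCH:-}" ] && [ -f "$CIFAR_BATCH" ]; then
    echo "airplane records: $(count_label "$CIFAR_BATCH" 0)"
fi
```

The pipeline reads the whole file byte by byte, so it can take a while on a full batch.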

First Step with learning rate 0.00005

./bin/wgan_release --run-wgan --log [log_file_name] --batch-size 64 \
--d-iters-by-g-iter 5 --main-iter 7800 --z-vector-bin-file [z_vector_file_name] --z-vector-size 100 \
--dataset Cifar10 --data-src-path ./bin/data/lfw_funneled --output-path [output_folder] \
--solver-d-model ./models/solver_d_lr_A.prototxt --solver-g-model ./models/solver_g_lr_A.prototxt

Second Step with learning rate 0.00002

./bin/wgan_release --run-wgan --log [log_file_name] --batch-size 64 \
--d-iters-by-g-iter 5 --main-iter 7800 --z-vector-bin-file [z_vector_file_name] --z-vector-size 100 \
--dataset Cifar10 --data-src-path ./bin/data/lfw_funneled --output-path [output_folder] \
--solver-d-model ./models/solver_d_lr_B.prototxt --solver-g-model ./models/solver_g_lr_B.prototxt \
--solver-d-state [output_folder]/wgan_d_iter_40000.solverstate --solver-g-state [output_folder]/wgan_g_iter_7500.solverstate

Third Step with learning rate 0.00001

./bin/wgan_release --run-wgan --log [log_file_name] --batch-size 64 \
--d-iters-by-g-iter 5 --main-iter 7800 --z-vector-bin-file [z_vector_file_name] --z-vector-size 100 \
--dataset Cifar10 --data-src-path ./bin/data/lfw_funneled --output-path [output_folder] \
--solver-d-model ./models/solver_d_lr_C.prototxt --solver-g-model ./models/solver_g_lr_C.prototxt \
--solver-d-state [output_folder]/wgan_d_iter_80000.solverstate --solver-g-state [output_folder]/wgan_g_iter_15500.solverstate

[Figure: cifar10_airplanes_generation]

You can see the training evolution:

[Animation]

loss Generator

[Figure: Loss Generator Evolution]

log(loss Discriminator)

[Figure: Loss Discriminator Evolution]

About

caffe_wgan

License: GNU General Public License v3.0

Languages

C 79.2%, Shell 7.4%, C++ 7.0%, Cuda 4.1%, Makefile 2.3%