christiancosgrove / pytorch-spectral-normalization-gan

Paper by Miyato et al. https://openreview.net/forum?id=B1QRgziT-


sn-wgan results

ferrine opened this issue · comments

Hi, I'm trying to apply spectral normalization to Wasserstein GANs. I failed to make it work in my own project, so I tried your repository to get more intuition about how to train them. However, I saw no training progress after about a day of training.

In the original WGAN paper they seem to train for 25 epochs. I've waited for 130 so far and got the following results (with your code):
image

Some introspection in the discriminator gave me interesting insights. Indeed, the WGAN critic is Lipschitz with constant ~2.4 according to gradient-norm histograms. However, SN-GAN and the regular GAN seem to have gradients with a larger average norm. Below are gradient-norm histograms for an SN-GAN, an SN-WGAN, and a regular GAN, each trained for 100 epochs on the CIFAR dataset (MNIST for the regular GAN). I used 5 discriminator iterations per generator update.
image
image
image

According to the histograms, SN-GAN seems to have better gradients for the generator. For the regular GAN I see small gradients even for images coming from the generator.
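For reference, the per-sample input-gradient norms behind histograms like these can be computed with autograd. This is a minimal sketch with a hypothetical toy critic (the actual models in the thread differ); the gradient norm of the critic w.r.t. its input lower-bounds the critic's Lipschitz constant:

```python
import torch
import torch.nn as nn

# Hypothetical tiny critic, for illustration only.
critic = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))

def input_grad_norms(critic, x):
    """Per-sample norms of d critic(x) / d x.

    Each norm is a lower bound on the critic's Lipschitz constant,
    which is how the ~2.4 estimate above can be obtained.
    """
    x = x.clone().requires_grad_(True)
    scores = critic(x).sum()  # sum so one backward pass covers the batch
    grads, = torch.autograd.grad(scores, x)
    return grads.flatten(1).norm(dim=1)

x = torch.randn(1024, 32)  # batch size 1024, as mentioned below
norms = input_grad_norms(critic, x)
print(norms.mean().item(), norms.max().item())
```

A histogram of `norms` over real and generated batches is what the plots above show.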

Did you manage to get satisfactory results for SN-WGAN, and if so, how?

My current intuition tells me the devil is in the gradients or their bias (I don't think I have that bias, since I use a batch size of 1024). Convergence might be too slow because of these gradients.

One question... are you using the resnet model or the dcgan model?

This is very interesting...

A few months ago someone reported problems with sn-WGAN.

Have you tried using smaller batch sizes? I've seen that paper and am aware that WGAN has biased sample gradients... maybe it's not understood how this interacts with spectral norm?

I used the resnet model for the first picture and my own implementation without resnet for the histograms.
In my setup I always pretrained the discriminator for 1 epoch. I tried small batch sizes, then moved to larger ones after I got poor results and found a paper about the bias. The resnet model produced bad-looking samples and never seemed to converge (using RMSprop for D and Adam with betas (0.5, 0.999) for G).
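The setup described above can be sketched as follows. This is a minimal, illustrative SN-WGAN loop on toy data (hypothetical models and hyperparameters, not the actual training script): spectrally normalized critic, WGAN losses, 5 critic updates per generator update, RMSprop for D and Adam (0.5, 0.999) for G:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Toy 2-D data stand-in for images; models are illustrative only.
D = nn.Sequential(spectral_norm(nn.Linear(2, 64)), nn.ReLU(),
                  spectral_norm(nn.Linear(64, 1)))
G = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2))

opt_D = torch.optim.RMSprop(D.parameters(), lr=5e-5)
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))

n_critic = 5  # discriminator iterations per generator update
for step in range(10):
    for _ in range(n_critic):
        real = torch.randn(64, 2) + 3.0          # stand-in "real" data
        fake = G(torch.randn(64, 8)).detach()
        # WGAN critic loss: maximize E[D(real)] - E[D(fake)]
        loss_D = D(fake).mean() - D(real).mean()
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()
    # Generator loss: maximize E[D(fake)]
    fake = G(torch.randn(64, 8))
    loss_G = -D(fake).mean()
    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()
```

Note that spectral normalization replaces the weight clipping of the original WGAN; there is no clipping or gradient penalty in this variant.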

The loss-function behaviour for the resnet model was bad (that's why I tried running your repo). I observed unstable, high-variance learning curves for the fake and real critic scores (loss = fake_loss + real_loss, with the right signs inside). There were no such effects for the non-resnet model. However, in both cases I also observed divergence (|fake_loss - real_loss|) of the critic scores for real and fake images. This divergence only grew over time (the same symptoms as in the link you provided). Since the critic is Lipschitz, I would expect it to only decrease over time (assuming I always have a perfect critic): if the two distributions get closer, the critic-score gap can't get larger.
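To make the quantity concrete: the divergence being monitored is the gap between mean critic scores on real and fake batches, which under a (nearly) Lipschitz critic is the Wasserstein distance estimate. A minimal sketch (hypothetical helper, for logging only):

```python
import torch

def critic_gap(d_real_scores, d_fake_scores):
    """|E[D(real)] - E[D(fake)]|: the quantity that, for a perfect
    Lipschitz critic, should shrink as the two distributions get closer."""
    return (d_real_scores.mean() - d_fake_scores.mean()).abs()

# Toy example with fixed scores.
gap = critic_gap(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 1.0]))
print(gap.item())  # 1.0
```

Logging this alongside the individual real/fake losses separates the "Wasserstein estimate" signal from the score-level oscillations described below.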

image
source

At the barycenter it seems to be zero, but when the domains overlap there are many solutions providing minimum loss and zero Wasserstein distance. (Maybe we can somehow help our critic?)

In toy problems like fitting a 1d Gaussian, the generator loss first grew and then decreased (as did the divergence I mentioned). BUT! When I get too close to the optimum, I see oscillations in the discriminator scores that affect the divergence but not the Wasserstein distance estimate.

The symptoms I get in GANs make me think that either:

  • I do not train the discriminator to convergence,
  • or it heavily violates the Lipschitz constraint (which I don't believe),
  • or convergence is too slow, but why? These guys got it to work in 25 epochs,
  • or there are other theoretical problems that are not well studied.
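The second bullet can be checked directly. For a network of linear layers and 1-Lipschitz activations like ReLU, the product of per-layer spectral norms (largest singular values) upper-bounds the network's Lipschitz constant, so if this product is near 1 the constraint is not heavily violated. A quick sketch with a hypothetical model:

```python
import torch
import torch.nn as nn

def lipschitz_upper_bound(model):
    """Product of per-layer spectral norms of the weight matrices:
    an upper bound on the Lipschitz constant of a linear/ReLU network."""
    bound = 1.0
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # ord=2 gives the largest singular value of the weight matrix
            bound *= torch.linalg.matrix_norm(m.weight, ord=2).item()
    return bound

# Hypothetical critic for illustration; run this on the real one.
net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
print(lipschitz_upper_bound(net))
```

With spectral normalization applied, each factor should sit close to 1; comparing this bound against the input-gradient lower bound (~2.4 above) brackets the true constant.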