nieshaoshuai / prediction_gan

PyTorch Impl. of Prediction Optimizer (to stabilize GAN training)

Prediction Optimizer (to stabilize GAN training)

Introduction

This is a PyTorch implementation of the 'prediction method' introduced in the following paper:

  • Abhay Yadav et al., Stabilizing Adversarial Nets with Prediction Methods, ICLR 2018

The authors propose a simple but effective method to stabilize GAN training. With this Prediction Optimizer, you can apply the method to existing GAN code with only a few extra lines. This implementation is compatible with most PyTorch optimizers and network architectures. (Please let me know if you run into any issues using it.)
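For reference, here is the prediction update as we read it from the paper, for a minimax problem min_u max_v L(u, v) with step sizes α and β: take a gradient step in u, extrapolate u along that step, and let v react to the extrapolated (predicted) u. The factor s is our assumption about what the 'step' argument of lookahead() controls; s = 1 gives the paper's update, and we have not checked this against prediction.py.

```latex
\begin{aligned}
u_{k+1}       &= u_k - \alpha \, \nabla_u \mathcal{L}(u_k, v_k) \\
\bar{u}_{k+1} &= u_{k+1} + s \, (u_{k+1} - u_k) \qquad (s = 1 \text{ in the paper}) \\
v_{k+1}       &= v_k + \beta \, \nabla_v \mathcal{L}(\bar{u}_{k+1}, v_k)
\end{aligned}
```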

How-to-use

Instructions

  • Import prediction.py
    • from prediction import PredOpt
  • Initialize it just like an optimizer
    • pred = PredOpt(net.parameters())
  • Run the model inside a 'with' block to get outputs computed with the predicted parameters.
    • The 'step' argument controls the lookahead step size (1.0 by default)
    • with pred.lookahead(step=1.0):
          output = net(input)
  • Call pred.step() after each update of the network parameters
    • optim_net.step()
      pred.step()
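To make the two calls concrete, here is a minimal sketch of how a PredOpt-style helper could work. `PredOptSketch` is a hypothetical name, and its internals (remembering the last update θ_{k+1} − θ_k in step(), then temporarily adding step × update to every parameter inside lookahead()) are our assumption, not a copy of prediction.py. Because this sketch restores the parameters in place when the `with` block exits, any backward pass that should see the predicted parameters has to run inside the block.

```python
import contextlib

import torch
import torch.nn as nn


class PredOptSketch:
    """Hypothetical sketch of a prediction helper (NOT the repo's PredOpt).

    step() remembers the latest update theta_{k+1} - theta_k;
    lookahead(step) temporarily sets theta <- theta + step * update.
    """

    def __init__(self, params):
        self.params = list(params)
        with torch.no_grad():
            # Parameters as of the most recent step() call (theta_k).
            self.prev = [p.detach().clone() for p in self.params]
            # Latest update; zero until step() has been called once.
            self.diff = [torch.zeros_like(p) for p in self.params]

    @torch.no_grad()
    def step(self):
        # Call right after the real optimizer's step().
        for p, prev, d in zip(self.params, self.prev, self.diff):
            d.copy_(p - prev)
            prev.copy_(p)

    @contextlib.contextmanager
    def lookahead(self, step=1.0):
        # Back up the current parameters and move them to the predicted values.
        with torch.no_grad():
            backup = [p.detach().clone() for p in self.params]
            for p, d in zip(self.params, self.diff):
                p.add_(d, alpha=step)
        try:
            yield
        finally:
            # Restore the real (non-predicted) parameters on exit.
            with torch.no_grad():
                for p, b in zip(self.params, backup):
                    p.copy_(b)


# Usage on a one-weight model, so the numbers are easy to follow.
net = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    net.weight.fill_(1.0)                   # deterministic starting weight
opt = torch.optim.SGD(net.parameters(), lr=0.1)
pred = PredOptSketch(net.parameters())      # snapshot taken after the fill

opt.zero_grad()
net(torch.ones(1, 1)).sum().backward()      # gradient w.r.t. the weight is 1
opt.step()                                  # weight: 1.0 -> about 0.9
pred.step()                                 # remembers the update (about -0.1)

with pred.lookahead(step=1.0):
    predicted_w = net.weight.item()         # about 0.9 + (-0.1) = 0.8
restored_w = net.weight.item()              # back to about 0.9
print(predicted_w, restored_w)
```

Keeping a per-parameter copy of the previous values doubles the parameter memory; that is the price of being able to extrapolate without touching the wrapped optimizer.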

Samples

  • A complete example is included in this repository (example_gan.py)
  • A sample snippet:
    import torch.optim as optim
    from prediction import PredOpt

    # ...

    optim_G = optim.Adam(netG.parameters(), lr=0.01)
    optim_D = optim.Adam(netD.parameters(), lr=0.01)

    pred_G = PredOpt(netG.parameters())        # Create a prediction optimizer for the target parameters
    pred_D = PredOpt(netD.parameters())

    for i, data in enumerate(dataloader, 0):
        # (1) Train D with samples from the predicted generator
        with pred_G.lookahead(step=1.0):       # Inside the 'with' block, netG runs with predicted parameters
            fake_predicted = netG(Z)

            # Compute D's loss and gradients (detach fake_predicted so no gradients reach G)

            optim_D.step()
            pred_D.step()                      # Record D's update so that pred_D can predict it

        # (2) Train G against the predicted discriminator
        with pred_D.lookahead(step=1.0):       # 'Predicted D'
            fake = netG(Z)                     # Draw samples from the real generator (not the predicted one)
            D_outs = netD(fake)

            # Compute G's loss and gradients while D still holds its predicted parameters

        optim_G.step()
        pred_G.step()                          # Call PredOpt.step() after each update
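The same loop structure can be tried on a two-parameter toy problem, min_x max_y x·y, whose equilibrium is (0, 0). Here x plays the generator's role and y the discriminator's, and y is updated against the predicted x, as D is trained against the predicted G above. This is our own illustration, not an experiment from the repository or the paper, and it uses plain Python floats instead of PredOpt so that the prediction step is visible in one line.

```python
import math


def run(use_prediction, steps=1000, lr=0.1):
    """Alternating gradient descent (x) / ascent (y) on L(x, y) = x * y.

    If use_prediction is True, y's update uses the predicted
    x_bar = x_new + (x_new - x_old) instead of x_new (prediction step 1.0).
    Returns the final distance from the equilibrium (0, 0).
    """
    x, y = 1.0, 1.0
    for _ in range(steps):
        x_old = x
        x = x - lr * y                      # descent on x: dL/dx = y
        x_used = x + (x - x_old) if use_prediction else x
        y = y + lr * x_used                 # ascent on y: dL/dy = x, evaluated at x_used
    return math.hypot(x, y)


plain = run(use_prediction=False)
predicted = run(use_prediction=True)
print(f"plain: {plain:.4f}  with prediction: {predicted:.4f}")
```

Without prediction, the alternating update conserves x² + y² − lr·x·y, so the iterates keep circling the equilibrium at roughly their starting distance. With prediction, the linear update matrix has determinant 1 − lr², so the iterates spiral in toward (0, 0).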

Output samples

You can find more images at sanghoon#3 and sanghoon#4

Training w/ large learning rate (0.01)

[Figure: CIFAR and CelebA samples at epoch 25, lr = 0.01; left: vanilla DCGAN, right: DCGAN w/ prediction (step=1.0)]

Training w/ medium learning rate (1e-4)

[Figure: CIFAR and CelebA samples at epoch 25, lr = 1e-4; left: vanilla DCGAN, right: DCGAN w/ prediction (step=1.0)]

Training w/ small learning rate (1e-5)

[Figure: CIFAR and CelebA samples at epoch 25, lr = 1e-5; left: vanilla DCGAN, right: DCGAN w/ prediction (step=1.0)]

External links

TODOs

  • Implement as an optimizer
  • Support pip install
  • Add some experimental results
