hichoe95 / PyTorch_WGAN_GP

Wasserstein GAN with gradient penalty tutorial, PyTorch ver.


PyTorch WGAN GP

This repository currently contains training code only. Pretrained weights will be uploaded later.

Version

  • pytorch=1.4.0
  • python3.6
  • cuda10.0.x
  • cudnn7.6.3
  • environment.yaml should be used for reference only, since it has too many dependencies.

Dataset

Dataset        CelebA HQ      FFHQ (thumbnails)
size           1024 x 1024    128 x 128
# of images    30000          70000

(During training, CelebA HQ images are resized to 128 x 128 or 64 x 64.)

Before training, change the dataset directory name in data_loder.py to match your local setup.

Options and Help

wgan_gp$ python main.py -h
usage: main.py [-h] [--main_gpu MAIN_GPU] [--use_tensorboard USE_TENSORBOARD]
               [--checkpoint_dir CHECKPOINT_DIR] [--log_dir LOG_DIR]
               [--image_name IMAGE_NAME] [--train_data TRAIN_DATA]
               [--optim OPTIM] [--lr LR] [--beta1 BETA1] [--beta2 BETA2]
               [--latent_dim LATENT_DIM]
               [--generator_upsample GENERATOR_UPSAMPLE]
               [--weight_init WEIGHT_INIT] [--norm_g NORM_G] [--norm_d NORM_D]
               [--nonlinearity NONLINEARITY] [--slope SLOPE]
               [--batch_size BATCH_SIZE] [--iter_num ITER_NUM]
               [--img_size IMG_SIZE] [--loss LOSS] [--n_critic N_CRITIC]
               [--lambda_gp LAMBDA_GP]

optional arguments:
  -h, --help            show this help message and exit
  --main_gpu MAIN_GPU   main gpu index
  --use_tensorboard USE_TENSORBOARD
                        Tensorboard
  --checkpoint_dir CHECKPOINT_DIR
                        full name is './checkponit'.format(main_gpu)
  --log_dir LOG_DIR     dir for tensorboard
  --image_name IMAGE_NAME
                        sample image name
  --train_data TRAIN_DATA
                        celeba or ffhq
  --optim OPTIM         Adam or RMSprop
  --lr LR               learning rate
  --beta1 BETA1         For Adam optimizer.
  --beta2 BETA2         For Adam optimizer.
  --latent_dim LATENT_DIM
                        dimension of latent vector
  --generator_upsample GENERATOR_UPSAMPLE
                        if False, using ConvTranspose.
  --weight_init WEIGHT_INIT
                        weight init from normal dist
  --norm_g NORM_G       inorm : instancenorm, bnorm : batchnorm, lnorm :
                        layernorm or None for Generator
  --norm_d NORM_D       inorm : instancenorm, bnorm : batchnorm, lnorm :
                        layernorm or None for discriminator(critic)
  --nonlinearity NONLINEARITY
                        relu or leakyrelu
  --slope SLOPE         if using leakyrelu, you can use this option.
  --batch_size BATCH_SIZE
                        size of the batches
  --iter_num ITER_NUM   number of iterations of training
  --img_size IMG_SIZE   size of each image dimension
  --loss LOSS           wgangp or bce, default is wgangp
  --n_critic N_CRITIC   number of training steps for discriminator per iter
  --lambda_gp LAMBDA_GP
                        amount of gradient penalty loss
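
Several options above (for example --use_tensorboard and --generator_upsample) take a value such as True or False. The repository's own parsing code is not shown here, so the following is only a minimal sketch of one common way to handle such flags with argparse. The helper name str2bool, the function build_parser, and all default values are assumptions made for illustration. The sketch avoids type=bool, because bool("False") is True in Python and every non-empty string would then be parsed as True.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false strings to a bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


def build_parser():
    """Illustrative subset of the options listed above; defaults are assumed."""
    parser = argparse.ArgumentParser(description="WGAN-GP options (sketch)")
    parser.add_argument("--use_tensorboard", type=str2bool, default=True,
                        help="Tensorboard")
    parser.add_argument("--generator_upsample", type=str2bool, default=False,
                        help="if False, using ConvTranspose.")
    parser.add_argument("--n_critic", type=int, default=5,
                        help="number of training steps for discriminator per iter")
    parser.add_argument("--lambda_gp", type=float, default=10.0,
                        help="amount of gradient penalty loss")
    return parser


if __name__ == "__main__":
    # Parse an explicit argument list so the sketch runs without a shell.
    args = build_parser().parse_args(["--generator_upsample", "False"])
    print(args.generator_upsample, args.n_critic, args.lambda_gp)
```

With a converter like this, passing --generator_upsample False yields the boolean False rather than a truthy string.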

Training

In my experiments, the 'RMSprop' optimizer gave the best training performance.

wgan_gp$ python main.py --main_gpu 4 \
                        --log_dir gpu4 \
                        --train_data celeba \
                        --latent_dim 128 \
                        --image_name gpu_4.png \
                        --batch_size 32 \
                        --n_critic 5 \
                        --lr 0.00005 \
                        --lambda_gp 10 \
                        --optim RMSprop \
                        --generator_upsample True \
                        --norm_g bnorm \
                        --norm_d lnorm \
                        --nonlinearity leakyrelu \
                        --slope 0.2 \
                        --loss wgangp
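
For readers unfamiliar with what --loss wgangp and --lambda_gp control, here is a minimal sketch of the standard WGAN-GP critic loss: the Wasserstein term plus lambda_gp times a penalty that pushes the critic's gradient norm toward 1 on points interpolated between real and generated images. This is not taken from the repository's code. The names gradient_penalty and critic_loss, and the toy linear critic, are hypothetical and chosen only for illustration.

```python
import torch
from torch import nn


def gradient_penalty(critic, real, fake):
    """Two-sided penalty: mean over the batch of (||grad critic(x_hat)||_2 - 1)^2."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over C, H, W.
    eps = torch.rand(batch_size, 1, 1, 1, device=real.device)
    mixed = (eps * real + (1.0 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(mixed)
    # create_graph=True keeps the graph so the penalty itself can be backpropagated.
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


def critic_loss(critic, real, fake, lambda_gp=10.0):
    """Critic objective to minimize: E[D(fake)] - E[D(real)] + lambda_gp * GP."""
    fake = fake.detach()  # the generator is not updated in the critic step
    wasserstein = critic(fake).mean() - critic(real).mean()
    return wasserstein + lambda_gp * gradient_penalty(critic, real, fake)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy critic and random tensors stand in for the real model and data.
    toy_critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
    real = torch.randn(4, 3, 8, 8)
    fake = torch.randn(4, 3, 8, 8)
    loss = critic_loss(toy_critic, real, fake, lambda_gp=10.0)
    loss.backward()
    print(loss.item())
```

In a training loop following the command above, a step like this would typically run --n_critic times (5 here) for every generator update. Batch normalization is usually avoided in the critic with this loss because the penalty is computed per sample, which is consistent with the choice of lnorm for --norm_d.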

Results - CelebA HQ

Tensorboard

If training runs on a remote server, you can open the TensorBoard window locally by forwarding a port. Run this command on your local machine:

ssh -NfL localhost:8898:localhost:6009 [USERID]@[IP]

'8898' is an arbitrary port number on the local machine, and '6009' is an arbitrary port number on the remote server (the one TensorBoard is serving on).
