tmralmeida / VGAN

Implementation of the VDB paper.

VGAN

PyTorch (v1.9.0) implementation of "Variational Discriminator Bottleneck: Improving Imitation Learning, Inverse RL, and GANs by Constraining Information Flow", developed for the Deep Learning and GANs WASP course '21.
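For orientation: in a variational discriminator bottleneck, the discriminator's encoder maps an image x to a Gaussian N(mu(x), diag sigma^2(x)), a latent z is drawn with the reparameterization trick, and the average KL divergence from this Gaussian to a standard-normal prior is kept below an information budget I_c. The sketch below computes those two pieces in plain Python for a single vector. The helper names sample_latent and kl_to_standard_normal are hypothetical and not taken from this repository, which works on PyTorch tensors.

```python
import math
import random


def sample_latent(mu, logvar, rng=random):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)) for one encoded sample.

    Closed form summed over dimensions: 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1).
    """
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))


if __name__ == "__main__":
    mu = [0.5, -0.2, 0.0]       # illustrative encoder means
    logvar = [0.0, -1.0, 0.3]   # illustrative encoder log-variances
    print("z  =", sample_latent(mu, logvar))
    print("KL =", kl_to_standard_normal(mu, logvar))
```

Averaging kl_to_standard_normal over a batch gives the E[KL] quantity that the bottleneck constrains against I_c.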

It has been tested on both RGB and grayscale data, using the CIFAR-10 and FER-2013 datasets respectively.
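What separates the two settings is the number of input channels. The hypothetical helper below (the names NATIVE_SHAPES, input_shape and flat_size are ours) records the native image sizes of the public datasets, 32x32 RGB for CIFAR-10 and 48x48 grayscale for FER-2013, and shows how a model could size its input layer. It assumes images are used at native resolution; the repository may resize them.

```python
# Native (channels, height, width) of each dataset's images.
# Assumption: images are fed at native resolution; the repo may resize them.
NATIVE_SHAPES = {
    "CIFAR-10": (3, 32, 32),   # RGB
    "FER-2013": (1, 48, 48),   # grayscale
}


def input_shape(dataset):
    """Return the (C, H, W) shape for a dataset name; raise on unknown names."""
    try:
        return NATIVE_SHAPES[dataset]
    except KeyError:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {sorted(NATIVE_SHAPES)}"
        ) from None


def flat_size(dataset):
    """Number of values in one image, e.g. for a fully connected first layer."""
    c, h, w = input_shape(dataset)
    return c * h * w
```

A convolutional discriminator would only need the channel count C, since the spatial size is handled by the convolutions.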

Requirements

Usage

Run train.py with the desired options:

python train.py [-h] [--dataset {CIFAR-10,FER-2013}]
                [--model {VGAN,VGAN-GP}] [--batch_size BATCH_SIZE]
                [--num_workers NUM_WORKERS] [--epochs EPOCHS]
                [--lr_gen LR_GEN] [--lr_disc LR_DISC]
                [--ic IC] [--beta BETA]
                [--alpha ALPHA] [--save_dir SAVE_DIR]
                [--nimgs_save NIMGS_SAVE]

For help on the optional arguments, run: python train.py -h
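The synopsis above matches what Python's argparse prints for a parser like the one below. This is a hypothetical reconstruction, not the actual train.py: the flag names and choices come from the synopsis, but the argument types are our guesses and defaults are left as None because this README does not state them.

```python
import argparse


def build_parser():
    # Hypothetical mirror of the train.py synopsis. Types are assumptions,
    # and defaults stay None because the README does not list them.
    p = argparse.ArgumentParser(description="Train a VGAN / VGAN-GP model (sketch).")
    p.add_argument("--dataset", choices=["CIFAR-10", "FER-2013"])
    p.add_argument("--model", choices=["VGAN", "VGAN-GP"])
    p.add_argument("--batch_size", type=int)
    p.add_argument("--num_workers", type=int)
    p.add_argument("--epochs", type=int)
    p.add_argument("--lr_gen", type=float, help="generator learning rate")
    p.add_argument("--lr_disc", type=float, help="discriminator learning rate")
    p.add_argument("--ic", type=float, help="information budget I_c (assumed meaning)")
    p.add_argument("--beta", type=float, help="initial beta (assumed meaning)")
    p.add_argument("--alpha", type=float, help="beta step size (assumed meaning)")
    p.add_argument("--save_dir")
    p.add_argument("--nimgs_save", type=int)
    return p


if __name__ == "__main__":
    # Same flags as the training example in this README.
    args = build_parser().parse_args(
        ["--dataset", "CIFAR-10", "--batch_size", "128", "--epochs", "200"]
    )
    print(args.dataset, args.batch_size, args.epochs)
```

With choices set, argparse rejects an unsupported value such as --model WGAN with a usage error before any training code runs.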

Running: Training

python train.py --dataset CIFAR-10 --batch_size 128 --epochs 200
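During training, the VDB paper adapts the Lagrange multiplier beta by dual gradient ascent, beta <- max(0, beta + alpha_beta * (E[KL] - I_c)), and adds beta * (E[KL] - I_c) to the discriminator loss. The sketch below assumes, without confirmation from the code, that the --ic, --beta and --alpha options correspond to I_c, the initial beta and the step size alpha_beta. The helper names and the numbers in the demo loop are illustrative only.

```python
def update_beta(beta, mean_kl, ic, alpha):
    """One dual gradient-ascent step on the Lagrange multiplier beta.

    Assumed mapping to train.py options: ic -> --ic, alpha -> --alpha,
    initial beta -> --beta. beta grows while the batch-average KL exceeds
    the budget ic and shrinks (clipped at 0) once it falls below.
    """
    return max(0.0, beta + alpha * (mean_kl - ic))


def bottleneck_penalty(beta, mean_kl, ic):
    """Constraint term added to the discriminator loss: beta * (E[KL] - ic)."""
    return beta * (mean_kl - ic)


if __name__ == "__main__":
    beta, ic, alpha = 0.1, 0.5, 0.01  # illustrative values, not the repo's defaults
    for mean_kl in [2.0, 1.2, 0.6, 0.3]:  # made-up batch-average KL values
        penalty = bottleneck_penalty(beta, mean_kl, ic)
        beta = update_beta(beta, mean_kl, ic, alpha)
        print(f"KL={mean_kl:.1f}  penalty={penalty:.4f}  new beta={beta:.4f}")
```

Clipping at zero keeps beta a valid multiplier for an inequality constraint: once the KL is under budget, the penalty fades instead of rewarding extra information flow.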

Running: Evaluation

python evaluate.py --dataset FER-13 --generator_path logs/FER-13/VGAN/gen.pth --discriminator_path logs/FER-13/VGAN/disc.pth

Results

FID scores on 10k generated samples (lower is better):

Dataset / Method   VGAN   VGAN-GP
CIFAR-10           25.1   17.1
FER-13             90.5   80.6

Note: FID scores were computed with the torchmetrics implementation.
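FID compares Gaussian fits to Inception features of real and generated images: FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)). torchmetrics evaluates this with full covariance matrices of Inception-v3 features. The sketch below only illustrates the formula in the simplified case of diagonal covariances, where the matrix square root reduces to an element-wise square root. It is not how the scores above were produced, and fid_diagonal is a hypothetical name.

```python
import math


def fid_diagonal(mu_r, var_r, mu_g, var_g):
    """Frechet distance between N(mu_r, diag(var_r)) and N(mu_g, diag(var_g)).

    With diagonal covariances, Tr(S_r + S_g - 2 (S_r S_g)^(1/2)) becomes
    sum(var_r + var_g - 2 * sqrt(var_r * var_g)).
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))
    cov_term = sum(vr + vg - 2.0 * math.sqrt(vr * vg) for vr, vg in zip(var_r, var_g))
    return mean_term + cov_term


if __name__ == "__main__":
    # Toy 2-D statistics: means differ by 1 in the first dimension,
    # variances differ (1 vs 4) in the first dimension only.
    print(fid_diagonal([0.0, 0.0], [1.0, 1.0], [1.0, 0.0], [4.0, 1.0]))
```

Both terms are zero only when the two distributions share means and variances, which is why a lower FID indicates generated samples closer to the real data.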

Example for the FER-13 dataset

License: MIT License

