A PyTorch implementation of ViTGAN: Training GANs with Vision Transformers
- Use vectorized L2 distance in attention for Discriminator
- Overlapping Image Patches
- DiffAugment
- Self-modulated LayerNorm (SLN)
- Implicit Neural Representation for Patch Generation
- ExponentialMovingAverage (EMA)
- Balanced Consistency Regularization (bCR)
- Improved Spectral Normalization (ISN)
- Equalized Learning Rate
- Weight Modulation
- Python3
- einops
- pytorch_ema
- stylegan2-pytorch
- tensorboard
- wandb
pip install einops git+https://github.com/fadel/pytorch_ema stylegan2-pytorch tensorboard wandb
Train the model with the proposed parameters:
python main.py
Tensorboard
tensorboard --logdir runs/
The following are the parameters proposed in the paper for the CIFAR-10 dataset:
python main.py
The Generator has the following architecture:
For debugging purposes, the Generator is separated into a Vision Transformer (ViT) model and a SIREN model.
Given a seed, whose dimensionality is controlled by latent_dim, the ViT model creates an embedding for each patch of the final image. Those embeddings, combined with a Fourier position encoding (Jupyter Notebook), are fed to a SIREN network, which outputs the image patches; the patches are then stitched together into the final image.
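At a high level, the pipeline looks like this (a sketch with illustrative names and shapes, not the exact classes in this repo):

```python
from einops import rearrange

# `vit` maps a seed z to one embedding per patch; `siren` maps each
# embedding (plus a position encoding, see below) to that patch's pixels.
def generate(vit, siren, z, image_size=32, patch_size=4):
    patch_emb = vit(z)             # (B, num_patches, hidden_dimension)
    patches = siren(patch_emb)     # (B, num_patches, patch_size**2 * 3)
    # stitch the patches together into the final image
    return rearrange(
        patches, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
        h=image_size // patch_size, p1=patch_size, p2=patch_size, c=3,
    )
```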
The ViT part of the Generator differs from a standard Vision Transformer in the following ways:
- The input to the Transformer consists only of the position embeddings
- Self-Modulated Layer Norm (SLN) is used in place of LayerNorm
- There is no classification head
SLN is the only place where the seed is input to the network.
SLN consists of a regular LayerNorm, the result of which is multiplied by `gamma` and added to `beta`.
Both `gamma` and `beta` are computed by fully connected layers, a separate pair for each place SLN is applied.
The input and output dimensions of each of these fully connected layers are both equal to `hidden_dimension`.
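A minimal sketch of SLN, assuming the seed has already been mapped to a latent code `w` of size `hidden_dimension`:

```python
import torch.nn as nn

# SLN(h, w) = gamma(w) * LayerNorm(h) + beta(w)
# Each SLN site gets its own pair of fully connected layers.
class SLN(nn.Module):
    def __init__(self, hidden_dimension):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_dimension)
        self.gamma = nn.Linear(hidden_dimension, hidden_dimension)
        self.beta = nn.Linear(hidden_dimension, hidden_dimension)

    def forward(self, h, w):
        # h: (B, num_patches, hidden_dimension), w: (B, hidden_dimension)
        return self.gamma(w).unsqueeze(1) * self.ln(h) + self.beta(w).unsqueeze(1)
```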
A description of SIREN: [Blog Post] [Paper] [Colab Notebook]
In contrast to a regular SIREN, the desired output is not a single image. For this purpose, the patch embedding is combined with a position embedding.
The positional encoding used in ViTGAN is the Fourier position encoding, the code for which was taken from here: (Jupyter Notebook)
In my implementation, the input to the SIREN is the sum of a patch embedding and a position embedding.
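A sketch of such a Fourier encoding for per-patch pixel coordinates (the number of frequencies and the scaling are illustrative; see the linked notebook for the exact version used):

```python
import math
import torch

def fourier_encoding(coords, num_frequencies=8):
    # coords: (..., 2) coordinates in [-1, 1]
    freqs = (2.0 ** torch.arange(num_frequencies)) * math.pi  # (F,)
    angles = coords.unsqueeze(-1) * freqs                     # (..., 2, F)
    # concatenate sin and cos features and flatten: (..., 4 * F)
    return torch.cat([angles.sin(), angles.cos()], dim=-1).flatten(-2)
```

A projection to `hidden_dimension` (an assumption here, not spelled out above) makes these features summable with the patch embedding.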
Weight Modulation usually consists of a modulation and a demodulation module. After testing the network, I concluded that demodulation is not used in ViTGAN.
My implementation of the weight modulation is heavily based on CIPS. I have adjusted it to work for a fully-connected network by using a 1D convolution. The reason for using a 1D convolution instead of a linear layer is its `groups` argument, which applies every sample's own modulated weights in a single call and thereby improves performance by a factor of `batch_size`.
Each SIREN layer consists of a sin activation applied to a weight modulation layer. The sizes of the input, hidden, and output layers in a SIREN network can vary; thus, in case the input size differs from the size of the patch embedding, I define an additional fully connected layer, which converts the patch embedding to the appropriate size.
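A sketch of one such layer, with the modulated linear transform realized as a grouped 1D convolution (`omega_0` and the weight initialization follow the usual SIREN conventions; demodulation is omitted, as discussed above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModulatedSIRENLayer(nn.Module):
    def __init__(self, in_dim, out_dim, style_dim, omega_0=30.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim) / in_dim ** 0.5)
        self.style = nn.Linear(style_dim, in_dim)  # per-input-channel scales
        self.omega_0 = omega_0

    def forward(self, x, style):
        # x: (B, N, in_dim), style: (B, style_dim)
        b, n, in_dim = x.shape
        s = self.style(style)                          # (B, in_dim)
        w = self.weight.unsqueeze(0) * s.unsqueeze(1)  # (B, out_dim, in_dim)
        # Fold the batch into the channel axis; groups=b applies each
        # sample's own modulated weights in a single conv1d call.
        x = x.transpose(1, 2).reshape(1, b * in_dim, n)
        w = w.reshape(-1, in_dim, 1)                   # (B * out_dim, in_dim, 1)
        out = F.conv1d(x, w, groups=b).reshape(b, -1, n).transpose(1, 2)
        return torch.sin(self.omega_0 * out)           # (B, N, out_dim)
```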
The Discriminator has the following architecture:
The ViTGAN Discriminator is mostly a standard Vision Transformer network, with the following modifications:
- DiffAugment
- Overlapping Image Patches
- Use vectorized L2 distance in attention (see the sketch after this list)
- Improved Spectral Normalization (ISN)
- Balanced Consistency Regularization (bCR)
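Since the L2 attention has no dedicated subsection below, here is a single-head sketch of it, following "The Lipschitz Constant of Self-Attention": the attention logits are negative squared Euclidean distances between queries and keys, computed in vectorized form, with the query and key projections tied:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class L2Attention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.qk = nn.Linear(dim, dim)  # tied query/key projection
        self.v = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x):
        # x: (B, N, dim)
        q = k = self.qk(x)
        v = self.v(x)
        d2 = torch.cdist(q, k, p=2) ** 2           # pairwise squared L2, (B, N, N)
        attn = F.softmax(-self.scale * d2, dim=-1)
        return attn @ v
```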
For implementing DiffAugment, I used the following code:
[GitHub] [Paper]
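A usage sketch, assuming `DiffAugment_pytorch.py` from the linked repository is importable; `discriminator`, `real_images`, and `fake_images` are placeholders:

```python
from DiffAugment_pytorch import DiffAugment

policy = 'color,translation,cutout'
# both real and fake batches pass through the same differentiable
# augmentations before reaching the Discriminator
d_real = discriminator(DiffAugment(real_images, policy=policy))
d_fake = discriminator(DiffAugment(fake_images, policy=policy))
```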
The overlapping image patches are extracted with a convolution layer whose kernel is larger than its stride.
[Paper]
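A sketch of the idea, with illustrative sizes; the overlap comes from choosing a kernel larger than the stride:

```python
import torch.nn as nn
from einops import rearrange

class OverlappingPatches(nn.Module):
    def __init__(self, dim=384, patch_size=4, overlap=2):
        super().__init__()
        # kernel_size > stride makes neighbouring patches share pixels
        self.proj = nn.Conv2d(3, dim,
                              kernel_size=patch_size + 2 * overlap,
                              stride=patch_size,
                              padding=overlap)

    def forward(self, img):
        x = self.proj(img)                          # (B, dim, H/p, W/p)
        return rearrange(x, 'b d h w -> b (h w) d')
```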
The ISN implementation is based on the following implementation of Spectral Normalization:
[GitHub]
[Paper]
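A minimal sketch of ISN for a linear layer: the spectrally normalized weight `W / sigma(W)` is rescaled by `sigma(W_init)`, the spectral norm of the weight at initialization, so the layer keeps its initial scale. A single power-iteration step stands in for the full spectral-norm estimate:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ISNLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        # sigma(W_init), frozen at initialization
        self.register_buffer(
            'sigma_init',
            torch.linalg.matrix_norm(self.linear.weight, ord=2).detach())
        self.register_buffer('u', torch.randn(out_dim))

    def forward(self, x):
        w = self.linear.weight
        with torch.no_grad():  # one power-iteration step for sigma(W)
            v = F.normalize(w.t() @ self.u, dim=0)
            self.u = F.normalize(w @ v, dim=0)
        sigma = torch.dot(self.u, w @ v)
        return F.linear(x, self.sigma_init * w / sigma, self.linear.bias)
```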
Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021 [Paper]
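A sketch of the bCR loss terms, assuming `disc` returns per-sample logits and `augment` is an augmentation such as DiffAugment; both the real and the fake consistency terms are penalized, which is what makes the regularization balanced:

```python
def bcr_loss(disc, augment, real, fake, lambda_real=10.0, lambda_fake=10.0):
    # penalize the Discriminator for changing its output under augmentation
    loss_real = (disc(real) - disc(augment(real))).pow(2).mean()
    loss_fake = (disc(fake) - disc(augment(fake))).pow(2).mean()
    return lambda_real * loss_real + lambda_fake * loss_fake
```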
SIREN: Implicit Neural Representations with Periodic Activation Functions
Vision Transformer: [Blog Post]
L2 distance attention: The Lipschitz Constant of Self-Attention
Spectral Normalization reference code: [GitHub] [Paper]
Diff Augment: [GitHub] [Paper]
Fourier Position Embedding: [Jupyter Notebook]
Exponential Moving Average: [GitHub]
Balanced Consistency Regularization (bCR): [Paper]
StyleGAN2 Discriminator: [GitHub]