This repo is the official implementation of "Efficient generative adversarial networks using linear additive-attention Transformers".
By Emilio Morales-Juarez and Gibran Fuentes-Pineda.
Although the capacity of deep generative models for image generation, such as Diffusion Models (DMs) and Generative Adversarial Networks (GANs), has dramatically improved in recent years, much of their success can be attributed to computationally expensive architectures. This has limited their adoption and use to research laboratories and companies with large resources, while significantly raising the carbon footprint for training, fine-tuning, and inference. In this work, we present LadaGAN, an efficient generative adversarial network that is built upon a novel Transformer block named Ladaformer. The main component of this block is a linear additive-attention mechanism that computes a single attention vector per head instead of the quadratic dot-product attention. We employ Ladaformer in both the generator and discriminator, which reduces the computational complexity and overcomes the training instabilities often associated with Transformer GANs. LadaGAN consistently outperforms existing convolutional and Transformer GANs on benchmark datasets at different resolutions while being significantly more efficient. Moreover, LadaGAN shows competitive performance compared to state-of-the-art multi-step generative models (e.g. DMs) using orders of magnitude less computational resources.
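The key idea in the abstract is that additive attention produces one attention vector per head (one weight per token) instead of a tokens x tokens matrix. The NumPy sketch below illustrates that idea in the style of Fastformer-like additive attention. It is an assumption that this matches Ladaformer's attention in spirit; it is not the repository's Ladaformer code, and the normalization, value projection and output mixing of the actual block may differ. The scoring vector `w` is a hypothetical parameter name.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def additive_attention(q, k, w):
    """Fastformer-style linear additive attention (illustrative sketch only).

    q, k: (heads, tokens, head_dim) query and key projections.
    w:    (heads, head_dim) learned scoring vector per head (hypothetical parameter).
    Returns (out, alpha): alpha holds ONE attention vector of length `tokens`
    per head, instead of a tokens x tokens attention matrix.
    """
    head_dim = q.shape[-1]
    # One scalar score per (head, token): cost O(tokens * head_dim).
    scores = np.einsum("hnd,hd->hn", q, w) / np.sqrt(head_dim)
    alpha = softmax(scores, axis=-1)
    # Global query = attention-weighted sum of the queries: (heads, head_dim).
    global_q = np.einsum("hn,hnd->hd", alpha, q)
    # Broadcast the global query over every key element-wise: (heads, tokens, head_dim).
    out = global_q[:, None, :] * k
    return out, alpha


rng = np.random.default_rng(0)
heads, tokens, head_dim = 4, 64, 32  # e.g. an 8x8 grid of image tokens
q = rng.standard_normal((heads, tokens, head_dim))
k = rng.standard_normal((heads, tokens, head_dim))
w = rng.standard_normal((heads, head_dim))
out, alpha = additive_attention(q, k, w)
print(out.shape, alpha.shape)  # (4, 64, 32) (4, 64)
```

Because no tokens x tokens matrix is ever formed, compute and memory grow linearly with the number of tokens rather than quadratically, as with dot-product attention.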
- Python 3.9
- TensorFlow <= 2.13.1
A conda environment can be created and activated with:

```bash
conda create --name tf13 python=3.9.16
conda activate tf13
pip install tensorflow==2.13.1 numpy matplotlib pillow scipy huggingface-hub
```
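After activating the environment, a quick sanity check against the two requirements (Python 3.9, TensorFlow <= 2.13.1) can save a failed run. The sketch below uses only the standard library (`importlib.metadata.version`); it assumes the installed distribution is named `tensorflow`, which may not hold for platform-specific builds, and the helper names are hypothetical.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def parse_version(v):
    """Turn '2.13.1' or '2.14.0rc1' into a comparable (major, minor, patch) tuple."""
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # drop suffixes such as 'rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_requirements(py, tf_version):
    """Return a list of problems against Python 3.9 and TensorFlow <= 2.13.1."""
    problems = []
    if tuple(py[:2]) != (3, 9):
        problems.append(f"expected Python 3.9, found {py[0]}.{py[1]}")
    if tf_version is None:
        problems.append("TensorFlow is not installed")
    elif parse_version(tf_version) > (2, 13, 1):
        problems.append(f"TensorFlow {tf_version} is newer than 2.13.1")
    return problems


try:
    installed_tf = version("tensorflow")  # distribution name assumed to be 'tensorflow'
except PackageNotFoundError:
    installed_tf = None

issues = check_requirements(sys.version_info, installed_tf)
print("OK" if not issues else "\n".join(issues))
```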
Use `--file_pattern=<file_pattern>` and `--eval_dir=<eval_dir>` to specify the dataset path and the FID evaluation path, respectively:

```bash
python train.py --file_pattern=./data_path/*png --eval_dir=./eval_path/*png
```
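Both flags take glob patterns, so a pattern that matches nothing is an easy mistake to make. The sketch below is a hypothetical pre-flight helper (not part of the repository) that counts the files each pattern matches with Python's `glob` module before training is launched; `train.py` may resolve the patterns in its own way. Quoting the patterns on the command line is a safe habit, so the shell passes them through unexpanded.

```python
import glob
import os
import tempfile


def preflight(file_pattern, eval_dir):
    """Count files matched by the two glob patterns before launching train.py.

    Hypothetical helper: train.py may resolve the patterns in its own way.
    """
    train_files = sorted(glob.glob(file_pattern))
    eval_files = sorted(glob.glob(eval_dir))
    if not train_files:
        raise FileNotFoundError(f"no training images match {file_pattern!r}")
    if not eval_files:
        raise FileNotFoundError(f"no FID evaluation images match {eval_dir!r}")
    return len(train_files), len(eval_files)


# Demo on a throwaway directory tree that mimics ./data_path and ./eval_path.
with tempfile.TemporaryDirectory() as root:
    layout = {"data_path": ["a.png", "b.png", "c.png", "notes.txt"],
              "eval_path": ["x.png", "y.png"]}
    for folder, names in layout.items():
        os.makedirs(os.path.join(root, folder))
        for name in names:
            open(os.path.join(root, folder, name), "wb").close()
    counts = preflight(os.path.join(root, "data_path", "*png"),
                       os.path.join(root, "eval_path", "*png"))
print(counts)  # (3, 2)
```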
CIFAR-10 and CelebA training uses a single 12 GB GPU (RTX 3080 Ti):

| Model (CIFAR-10 32x32) | ADM-IP (80 steps) | StyleGAN2 | ViTGAN | LadaGAN |
|---|---|---|---|---|
| GPUs | Tesla V100 x 2 | - | - | RTX 3080 Ti x 1 |
| #Images | 69M | - | - | 68M |
| #Params | 57M | - | - | 19M |
| FLOPs | 9.0B | - | - | 0.7B |
| FID | 2.93 | 5.79 | 4.57 | 3.48 |

| Model (CelebA 64x64) | ADM-IP (80 steps) | StyleGAN2 | ViTGAN | LadaGAN |
|---|---|---|---|---|
| GPUs | Tesla V100 x 16 | - | - | RTX 3080 Ti x 1 |
| #Images | 138M | - | - | 72M |
| #Params | 295M | 24M | 38M | 19M |
| FLOPs | 103.5B | 7.8B | 2.6B | 0.7B |
| FID | 2.67 | - | 3.74 | 1.81 |

| Model (FFHQ 128x128) | ADM-IP (80 steps) | StyleGAN2 | ViTGAN | LadaGAN |
|---|---|---|---|---|
| #Images | 61M | - | - | 24M |
| #Params | 543M | - | - | 24M |
| FLOPs | 391.0B | 11.5B | 11.8B | 4.3B |
| FID | 6.89 | - | - | 4.48 |
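The efficiency gap between ADM-IP and LadaGAN can be read off the tables above as simple ratios. The sketch below divides ADM-IP's FLOPs and parameter counts by LadaGAN's for each dataset, assuming both FLOPs figures are measured the same way. If the ADM-IP figures count a single denoising step, sampling with 80 steps would multiply its cost further; the ratios also say nothing about wall-clock time.

```python
# (FLOPs in billions, parameters in millions) as reported in the tables.
reported = {
    "CIFAR-10 32x32": {"ADM-IP": (9.0, 57), "LadaGAN": (0.7, 19)},
    "CelebA 64x64": {"ADM-IP": (103.5, 295), "LadaGAN": (0.7, 19)},
    "FFHQ 128x128": {"ADM-IP": (391.0, 543), "LadaGAN": (4.3, 24)},
}

ratios = {}
for dataset, models in reported.items():
    adm_flops, adm_params = models["ADM-IP"]
    lada_flops, lada_params = models["LadaGAN"]
    # How many times larger ADM-IP is than LadaGAN on each axis.
    ratios[dataset] = {"flops": adm_flops / lada_flops,
                       "params": adm_params / lada_params}
    print(f"{dataset}: FLOPs ratio {ratios[dataset]['flops']:.1f}, "
          f"params ratio {ratios[dataset]['params']:.1f}")
```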
LadaGAN FID scores are computed with PyTorch FID.
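PyTorch FID fits a Gaussian to Inception-v3 features of real and generated images and reports the Fréchet distance between the two, d^2 = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)). The NumPy/SciPy sketch below implements only that formula. It uses random vectors in place of Inception features and omits the tool's numerical safeguards, so it illustrates the metric and is not a substitute for the actual evaluation.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Matrix square root of the covariance product; numerical error can leave
    # tiny imaginary parts, so keep only the real part.
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * np.trace(covmean))


# Stand-in "features": random vectors instead of Inception-v3 activations.
rng = np.random.default_rng(0)
feats = rng.standard_normal((2000, 8))
mu, sigma = feats.mean(axis=0), np.cov(feats, rowvar=False)
fid_same = frechet_distance(mu, sigma, mu, sigma)           # ~0 for identical statistics
fid_shifted = frechet_distance(mu, sigma, mu + 1.0, sigma)  # ~8: unit mean shift in 8 dims
```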
Hyperparameters can be adjusted in the `config.py` file.
Implementation notes:
- This model depends on other files that may be licensed under different open source licenses.
- LadaGAN uses Differentiable Augmentation, which is released under the BSD 2-Clause "Simplified" License.
- FID evaluation.
- Efficient patch generation with XLA.
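Differentiable Augmentation applies the same random, differentiable transformations to both real and generated images before the discriminator sees them, and back-propagates through them when the generator is updated. The NumPy sketch below shows two such operations, brightness and translation (with the 0.125 translation ratio used in the original DiffAugment code), on NHWC arrays. Which augmentation policy LadaGAN actually uses is set in the repository. NumPy has no autodiff, so in training these operations would be written in TensorFlow so that gradients can flow through them.

```python
import numpy as np


def rand_brightness(x, rng):
    # Add one random offset in [-0.5, 0.5) to every pixel of each image (x: N, H, W, C).
    return x + (rng.random((x.shape[0], 1, 1, 1)) - 0.5)


def rand_translation(x, rng, ratio=0.125):
    # Shift each image by up to ratio * size pixels; vacated pixels become 0.
    n, h, w, _ = x.shape
    max_dy, max_dx = int(h * ratio + 0.5), int(w * ratio + 0.5)
    out = np.zeros_like(x)
    for i in range(n):
        dy = int(rng.integers(-max_dy, max_dy + 1))
        dx = int(rng.integers(-max_dx, max_dx + 1))
        src_y = slice(max(0, -dy), h - max(0, dy))
        dst_y = slice(max(0, dy), h - max(0, -dy))
        src_x = slice(max(0, -dx), w - max(0, dx))
        dst_x = slice(max(0, dx), w - max(0, -dx))
        out[i, dst_y, dst_x] = x[i, src_y, src_x]
    return out


def diff_augment(x, rng):
    # Same random policy for real and generated batches before the discriminator.
    return rand_translation(rand_brightness(x, rng), rng)


rng = np.random.default_rng(0)
real = rng.random((4, 32, 32, 3))
fake = rng.random((4, 32, 32, 3))
aug_real, aug_fake = diff_augment(real, rng), diff_augment(fake, rng)
```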
Figure: training progress of the single-head attention maps.
```bibtex
@article{morales2024efficient,
  title={Efficient generative adversarial networks using linear additive-attention Transformers},
  author={Morales-Juarez, Emilio and Fuentes-Pineda, Gibran},
  journal={arXiv preprint arXiv:2401.09596},
  year={2024}
}
```
License: MIT