weleen / fixed-point-diffusion-models

Fixed Point Diffusion Models

Project Page · Paper


DiT samples

Roadmap

  • Code and paper release 🎉🎉
  • Jupyter notebook example
  • Pretrained model release (coming soon)
  • Code walkthrough and tutorial

Abstract

We introduce the Fixed Point Diffusion Model (FPDM), a novel approach to image generation that integrates the concept of fixed point solving into the framework of diffusion-based generative modeling. Our approach embeds an implicit fixed point solving layer into the denoising network of a diffusion model, transforming the diffusion process into a sequence of closely related fixed point problems. Combined with a new stochastic training method, this approach significantly reduces model size, reduces memory usage, and accelerates training. Moreover, it enables the development of two new techniques to improve sampling efficiency: reallocating computation across timesteps and reusing fixed point solutions between timesteps. We conduct extensive experiments with state-of-the-art models on ImageNet, FFHQ, CelebA-HQ, and LSUN-Church, demonstrating substantial improvements in performance and efficiency. Compared to the state-of-the-art DiT model, FPDM contains 87% fewer parameters, consumes 60% less memory during training, and improves image generation quality in situations where sampling computation or time is limited.
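To make "a sequence of closely related fixed point problems" concrete, here is a minimal toy sketch (not the code in models.py). It assumes a scalar contraction `toy_layer` standing in for the implicit layer and solves z* = f(z*, x) by plain iteration. The names `fixed_point_solve` and `toy_layer`, the tolerance, and the inputs are all hypothetical. Warm-starting a nearby problem from the previous solution converges in fewer iterations than starting from zero, which is the intuition behind reusing solutions between timesteps.

```python
import math


def fixed_point_solve(f, x, z0, tol=1e-6, max_iters=100):
    """Iterate z <- f(z, x) until successive iterates differ by less than tol.

    Returns the approximate fixed point z* (z* ~= f(z*, x)) and the number
    of iterations used. Plain (Picard) iteration; a real solver might use
    Anderson acceleration or Broyden's method instead.
    """
    z = z0
    for i in range(1, max_iters + 1):
        z_next = f(z, x)
        if abs(z_next - z) < tol:
            return z_next, i
        z = z_next
    return z, max_iters


def toy_layer(z, x):
    """Hypothetical stand-in for the implicit layer.

    |df/dz| = 0.5 * sech(z)**2 <= 0.5, so f is a contraction in z and
    plain iteration converges to a unique fixed point.
    """
    return 0.5 * math.tanh(z) + x


# Two closely related problems, e.g. inputs at adjacent timesteps.
x_t, x_next = 1.000, 1.001
z_t, _ = fixed_point_solve(toy_layer, x_t, z0=0.0)

# Cold start from zero vs. warm start from the previous solution.
_, iters_cold_next = fixed_point_solve(toy_layer, x_next, z0=0.0)
z_next, iters_warm = fixed_point_solve(toy_layer, x_next, z0=z_t)
print(f"cold start: {iters_cold_next} iterations, warm start: {iters_warm} iterations")
```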

Setup

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate DiT
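For the CPU-only case mentioned above, a small helper can write a copy of the environment file without the two GPU packages instead of editing it by hand. This is a sketch that assumes environment.yml lists one Conda dependency per line in the usual `- name[=version]` form (optionally with a `channel::` prefix). The function names and the output file name `environment_cpu.yml` are hypothetical.

```python
import re
from pathlib import Path

# Packages that can be dropped for CPU-only use, per the note above.
GPU_ONLY = {"cudatoolkit", "pytorch-cuda"}


def strip_gpu_deps(yaml_text: str) -> str:
    """Return yaml_text with cudatoolkit / pytorch-cuda dependency lines removed."""
    kept = []
    for line in yaml_text.splitlines():
        item = line.strip()
        if item.startswith("- "):
            spec = item[2:].strip().split("::")[-1]  # drop a channel prefix like 'nvidia::'
            name = re.split(r"[=<>!~\s]", spec, maxsplit=1)[0]  # drop any version spec
            if name in GPU_ONLY:
                continue
        kept.append(line)
    return "\n".join(kept) + "\n"


def write_cpu_env(src: str = "environment.yml", dst: str = "environment_cpu.yml") -> None:
    """Write a CPU-only copy of the Conda environment file."""
    Path(dst).write_text(strip_gpu_deps(Path(src).read_text()))
```

After calling `write_cpu_env()`, create the environment from the new file with `conda env create -f environment_cpu.yml`. The environment name inside the file is unchanged, so `conda activate DiT` still applies.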

Model

Our model definition, including all fixed point functionality, is included in models.py.
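As a rough mental model of the layer layout implied by the `--deq_pre_depth` and `--deq_post_depth` flags used below, the toy sketch assumes a few explicit layers run once, followed by a single shared layer applied repeatedly until it approaches a fixed point, and then a few more explicit layers. This is a hypothetical scalar illustration, not the structure of models.py. Presumably the parameter savings come from the repeated layer sharing one set of weights instead of stacking many distinct blocks. `fpdm_forward` and the lambda layers are invented for illustration.

```python
import math
from typing import Callable, List, Tuple

Layer = Callable[[float], float]


def fpdm_forward(x: float,
                 pre_layers: List[Layer],
                 fp_layer: Callable[[float, float], float],
                 post_layers: List[Layer],
                 num_iters: int,
                 z_init: float = 0.0) -> Tuple[float, float]:
    """Toy forward pass: explicit pre-layers -> implicit layer -> explicit post-layers.

    Returns (output, z) so a caller could reuse z as a warm start later.
    """
    h = x
    for layer in pre_layers:       # explicit layers, each run once (cf. --deq_pre_depth)
        h = layer(h)
    z = z_init
    for _ in range(num_iters):     # one shared layer applied repeatedly,
        z = fp_layer(z, h)         # conditioned on the pre-layer output h
    out = z
    for layer in post_layers:      # explicit layers, each run once (cf. --deq_post_depth)
        out = layer(out)
    return out, z


# Hypothetical toy layers: pre/post are simple affine maps, fp is a contraction in z.
pre = [lambda h: h + 0.5]
post = [lambda h: 2.0 * h]
fp = lambda z, h: 0.5 * math.tanh(z) + h

y, z = fpdm_forward(0.5, pre, fp, post, num_iters=30)
print(y, z)
```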

Training

Example training commands:

# Standard model
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py

# Fixed Point Diffusion Model
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --fixed_point True --deq_pre_depth 1 --deq_post_depth 1

# With v-prediction and zero-SNR
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 1 --deq_post_depth 1

# With v-prediction and zero-SNR, with 4 pre- and post-layers
accelerate launch --config_file aconfs/1_node_1_gpu_ddp.yaml --num_processes 8 train.py --output_subdir v_pred_exp --predict_v True --use_zero_terminal_snr True --fixed_point True --deq_pre_depth 4 --deq_post_depth 4
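The `--use_zero_terminal_snr` and `--predict_v` flags name two known techniques. The sketch below shows the standard versions, assuming train.py follows them: beta rescaling as in Algorithm 1 of Lin et al. ("Common Diffusion Noise Schedules and Sample Steps Are Flawed"), and the v-prediction target of Salimans and Ho. The linear schedule (1e-4 to 0.02) is only an illustrative input, and all function names are hypothetical. At zero terminal SNR the final noisy input is pure noise, so an epsilon target carries no information about the image there, whereas the v target reduces to -x0. This is why the two flags are used together.

```python
import math
from typing import List


def linear_betas(n: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """Linear beta schedule with n >= 2 steps (illustrative input only)."""
    return [beta_start + (beta_end - beta_start) * i / (n - 1) for i in range(n)]


def rescale_zero_terminal_snr(betas: List[float]) -> List[float]:
    """Rescale betas so the last timestep has alpha_bar = 0, i.e. SNR = 0."""
    # alpha_bar_t = prod_{s <= t} (1 - beta_s)
    alpha_bars, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        alpha_bars.append(prod)
    sqrt_ab = [math.sqrt(a) for a in alpha_bars]
    first, last = sqrt_ab[0], sqrt_ab[-1]
    # Shift so the last value becomes 0, then scale so the first value is unchanged.
    sqrt_ab = [(s - last) * first / (first - last) for s in sqrt_ab]
    alpha_bars = [s * s for s in sqrt_ab]
    # Recover per-step alphas from ratios of consecutive alpha_bars.
    alphas = [alpha_bars[0]] + [alpha_bars[i] / alpha_bars[i - 1]
                                for i in range(1, len(alpha_bars))]
    return [1.0 - a for a in alphas]


def v_target(x0: float, eps: float, alpha_bar: float) -> float:
    """v-prediction target: v = sqrt(alpha_bar) * eps - sqrt(1 - alpha_bar) * x0."""
    return math.sqrt(alpha_bar) * eps - math.sqrt(1.0 - alpha_bar) * x0


betas = rescale_zero_terminal_snr(linear_betas(1000))
print(betas[0], betas[-1])  # the final beta is exactly 1.0 (alpha_bar_T = 0)
```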

Sampling

Example sampling commands:

# Sample
python sample.py --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --num_sampling_steps 20

# Sample with fewer fixed point iterations per timestep, more timesteps, and solution reuse between timesteps
python sample.py --ckpt {checkpoint-path-from-above} --fixed_point True --fixed_point_pre_depth 1 --fixed_point_post_depth 1 --cfg_scale 4.0 --fixed_point_iters 12 --num_sampling_steps 40 --fixed_point_reuse_solution True
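The second command trades iterations per timestep for more timesteps. Its cost in implicit-layer evaluations is iterations times steps (12 x 40 = 480), so for a fixed budget, halving the iterations per step allows roughly twice as many steps. The toy loop below illustrates that accounting together with solution reuse, assuming `--fixed_point_reuse_solution` means "initialize each timestep's solve from the previous timestep's solution". It is a scalar sketch with hypothetical names (`sample_with_fixed_point_budget`, `toy_fp`, `xs`), not sample.py's actual loop.

```python
import math
from typing import Callable, List, Tuple


def sample_with_fixed_point_budget(xs: List[float],
                                   fp_layer: Callable[[float, float], float],
                                   iters_per_step: int,
                                   reuse_solution: bool) -> Tuple[List[float], int]:
    """Toy sampling loop with a fixed number of fixed point iterations per timestep.

    xs stands in for the per-timestep inputs to the implicit layer. With
    reuse_solution=True each solve starts from the previous timestep's z;
    otherwise it starts from 0. Returns the per-timestep solutions and the
    total number of implicit-layer evaluations.
    """
    z_prev, total_evals, zs = 0.0, 0, []
    for x in xs:
        z = z_prev if reuse_solution else 0.0
        for _ in range(iters_per_step):
            z = fp_layer(z, x)
            total_evals += 1
        zs.append(z)
        z_prev = z
    return zs, total_evals


# Hypothetical contraction and slowly varying inputs across 40 timesteps.
toy_fp = lambda z, x: 0.5 * math.tanh(z) + x
xs = [1.0 + 0.001 * t for t in range(40)]

_, evals = sample_with_fixed_point_budget(xs, toy_fp, iters_per_step=12, reuse_solution=True)
print(evals)  # 480 layer evaluations (12 iterations x 40 timesteps)
```

With only a couple of iterations per step, the reused solution at the final timestep is much closer to a true fixed point than a cold start from zero, because consecutive problems differ only slightly.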

Contribution

Pull requests are welcome!

Acknowledgements

  • The strong baseline from DiT:

    @article{Peebles2022DiT,
    title={Scalable Diffusion Models with Transformers},
    author={William Peebles and Saining Xie},
    year={2022},
    journal={arXiv preprint arXiv:2212.09748},
    }
    
  • The fast-DiT code from chuanyangjin.

  • All the great work from the CMU Locus Lab on Deep Equilibrium Models, which started with:

    @inproceedings{bai2019deep,
    author    = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
    title     = {Deep Equilibrium Models},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year      = {2019},
    }
    
  • L.M.K. thanks the Rhodes Trust for their scholarship support.
