TGATE-V1: Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models
Wentian Zhang* Haozhe Liu1* Jinheng Xie2* Francesco Faccio1,3 Mike Zheng Shou2 Jürgen Schmidhuber1,3
1 AI Initiative, King Abdullah University of Science and Technology 2 Show Lab, National University of Singapore 3 The Swiss AI Lab, IDSIA
TGATE-V2: Faster Diffusion Through Temporal Attention Decomposition
Haozhe Liu1,4* Wentian Zhang* Jinheng Xie2* Francesco Faccio1,3 Mengmeng Xu4 Tao Xiang4 Mike Zheng Shou2 Juan-Manuel Pérez-Rúa4 Jürgen Schmidhuber1,3
1 AI Initiative, King Abdullah University of Science and Technology 2 Show Lab, National University of Singapore 3 The Swiss AI Lab, IDSIA 4 Meta
Code and Technical Report will be released soon!
We explore the role of the attention mechanism during inference in text-conditional diffusion models. Empirical observations suggest that cross-attention outputs converge to a fixed point after several inference steps. The convergence time naturally divides the entire inference process into two phases: an initial phase for planning text-oriented visual semantics, which are then translated into images in a subsequent fidelity-improving phase. Cross-attention is essential in the initial phase but almost irrelevant thereafter. Self-attention, however, initially plays a minor role but becomes increasingly important in the second phase. These findings yield a simple and training-free method called TGATE, which efficiently generates images by caching and reusing attention outputs at scheduled time steps. Experiments show TGATE's broad applicability to various existing text-conditional diffusion models, which it speeds up by 10-50%.
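The convergence observation above can be probed with a simple check on recorded attention outputs. This is a minimal sketch under stated assumptions, not part of TGATE: `find_gate_step` and its relative-change threshold `tol` are hypothetical, each entry of `outputs` is assumed to be a flattened cross-attention output captured at one denoising step, and TGATE itself takes a user-chosen `gate_step`.

```python
import math
from typing import List, Optional, Sequence


def l2_norm(v: Sequence[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def find_gate_step(outputs: List[Sequence[float]], tol: float = 1e-2) -> Optional[int]:
    """Hypothetical heuristic: return the first step t whose cross-attention
    output changes by less than `tol` (relative L2) from step t - 1.

    outputs[t] is assumed to be the flattened cross-attention output at step t.
    Returns None if the outputs never settle within the recorded steps.
    """
    for t in range(1, len(outputs)):
        prev, cur = outputs[t - 1], outputs[t]
        change = l2_norm([c - p for c, p in zip(cur, prev)])
        scale = l2_norm(prev) or 1.0  # avoid dividing by zero for an all-zero output
        if change / scale < tol:
            return t
    return None


# Toy trajectory that approaches the fixed point [1.0, 2.0] geometrically.
trajectory = [[1.0 + 0.5 ** t, 2.0] for t in range(25)]
print(find_gate_step(trajectory))  # -> 6
```

A relative threshold is used so the check does not depend on the overall scale of the attention activations; the right `tol` for a real model would have to be chosen empirically.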
Images generated by diffusion models with and without TGATE. TGATE accelerates diffusion models without degrading generation quality. It is training-free and broadly complementary to existing work.
- Training-free.
- Easy to integrate into existing frameworks.
- Requires only a few lines of code.
- Supports CNN-based U-Nets, Transformers, and consistency models.
- 10%-50% speedup for different diffusion models.
- 2024/05/22: We have successfully extended TGATE to self-attention modules for greater acceleration! Stay tuned for a major update, which will be released in the coming weeks.
- 2024/04/14: We release TGATE v0.1.1 to support the playground-v2.5-1024 model.
- 2024/04/10: We release our package to PyPI. Check here for usage.
- 2024/04/04: The technical report is available on arXiv.
- 2024/04/04: TGATE for DeepCache (SD-XL) is released.
- 2024/03/30: TGATE for SD-1.5/2.1/XL is released.
- 2024/03/29: TGATE for LCM (SD-XL) and PixArt-Alpha is released.
- 2024/03/28: TGATE is open source.
Images generated by the diffusion model when the text embedding is fed at different denoising steps. The first row feeds the text embedding to the cross-attention modules at all steps, the second row only from the 1st to the 10th step, and the third row only from the 11th to the 25th step.
We summarize our observations as follows:
- Cross-attention converges early during inference, so the inference process can be characterized by two stages: semantics planning and fidelity improving. The impact of cross-attention is not uniform across these two stages.
- The semantics-planning stage embeds text through cross-attention to obtain visual semantics.
- The fidelity-improving stage improves generation quality without requiring cross-attention. In fact, using a null text embedding in this stage can improve performance.
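- Step 1: TGATE caches the attention outputs from the semantics-planning stage.

```python
if gate_step == cur_step:
    # Split the CFG batch into its unconditional and text-conditioned halves
    hidden_uncond, hidden_pred_text = hidden_states.chunk(2)
    # Cache their average for reuse in the fidelity-improving stage
    cache = (hidden_uncond + hidden_pred_text) / 2
```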
- Step 2: TGATE reuses the cache throughout the fidelity-improving stage.

```python
if cross_attn and (gate_step < cur_step):
    # Skip cross-attention and reuse the cached output
    hidden_states = cache
```
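The two steps can be seen end to end in the toy sketch below. It is an illustration under stated assumptions, not TGATE's implementation: `GatedCrossAttention` and `toy_cross_attention` are hypothetical names, plain Python lists stand in for tensors, and the first half of each batch is assumed to be the unconditional branch, matching the `chunk(2)` split in Step 1.

```python
from typing import Callable, List, Optional

Batch = List[float]  # toy stand-in for a tensor whose first half is the unconditional branch


class GatedCrossAttention:
    """Hypothetical wrapper illustrating the two TGATE steps with plain lists."""

    def __init__(self, attn_fn: Callable[[Batch], Batch], gate_step: int) -> None:
        self.attn_fn = attn_fn      # stand-in for the real cross-attention module
        self.gate_step = gate_step
        self.cache: Optional[Batch] = None
        self.calls = 0              # how many times cross-attention actually ran

    def __call__(self, hidden_states: Batch, cur_step: int) -> Batch:
        if self.gate_step < cur_step and self.cache is not None:
            # Step 2: fidelity-improving stage, reuse the cache and skip cross-attention.
            # Tile it if the caller still passes both CFG branches.
            return self.cache * (len(hidden_states) // len(self.cache))
        out = self.attn_fn(hidden_states)
        self.calls += 1
        if cur_step == self.gate_step:
            # Step 1: average the unconditional and text-conditioned halves.
            half = len(out) // 2
            hidden_uncond, hidden_pred_text = out[:half], out[half:]
            self.cache = [(u + t) / 2 for u, t in zip(hidden_uncond, hidden_pred_text)]
        return out


def toy_cross_attention(x: Batch) -> Batch:
    return [2.0 * v for v in x]  # any deterministic map works for the demo


gate = GatedCrossAttention(toy_cross_attention, gate_step=2)
for step in range(5):
    out = gate([1.0, 3.0], cur_step=step)  # batch = [uncond, text]
print(gate.calls, gate.cache, out)  # -> 3 [4.0] [4.0, 4.0]
```

Cross-attention runs only for steps 0 to `gate_step`; every later step returns the cached average. Tiling the cache when two branches arrive is a choice made for this toy so that it accepts either batch size.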
Model | MACs | Param | Latency | Zero-shot 10K-FID on MS-COCO |
---|---|---|---|---|
SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
SD-1.5 w/ TGATE | 9.875T | 815.557M | 4.313s | 20.789 |
SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
SD-2.1 w/ TGATE | 22.208T | 815.433M | 9.878s | 19.940 |
SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
SD-XL w/ TGATE | 84.438T | 2.024B | 27.932s | 22.738 |
Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
Pixart-Alpha w/ TGATE | 65.318T | 462.585M | 37.867s | 35.825 |
DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
DeepCache (SD-XL) w/ TGATE | 43.868T | - | 14.666s | 23.999 |
LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
LCM (SD-XL) w/ TGATE | 11.171T | 2.024B | 3.533s | 25.028 |
LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
LCM (Pixart-Alpha) w/ TGATE | 7.623T | 462.585M | 4.543s | 37.048 |
Latency is measured on a commercial NVIDIA GeForce GTX 1080 Ti GPU.
MACs and Params are calculated with calflops.
FID is calculated with PytorchFID.
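The relative gains in the table can be recomputed directly from its latency column. The snippet below only re-derives numbers already in the table; the two metrics it prints (fraction of time saved and speedup ratio) are assumed definitions, since the text does not state which metric its 10%-50% figure refers to.

```python
# Latency in seconds as (baseline, with TGATE), copied from the table above (GTX 1080 Ti).
LATENCY = {
    "SD-1.5": (7.032, 4.313),
    "SD-2.1": (16.121, 9.878),
    "SD-XL": (53.187, 27.932),
    "Pixart-Alpha": (61.502, 37.867),
    "DeepCache (SD-XL)": (19.931, 14.666),
    "LCM (SD-XL)": (3.805, 3.533),
    "LCM (Pixart-Alpha)": (4.733, 4.543),
}


def latency_reduction(base: float, tgate: float) -> float:
    """Fraction of wall-clock time saved: 1 - t_tgate / t_base."""
    return 1.0 - tgate / base


def speedup(base: float, tgate: float) -> float:
    """Throughput ratio: t_base / t_tgate."""
    return base / tgate


for name, (base, tgate) in LATENCY.items():
    print(f"{name:<20} time saved {latency_reduction(base, tgate):6.1%}   "
          f"speedup {speedup(base, tgate):.2f}x")
```

For example, SD-XL drops from 53.187s to 27.932s, which is about 47.5% of the time saved, or a 1.90x speedup.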
- pytorch>=2.0.0
- diffusers>=0.27.2
- transformers==4.37.2
- DeepCache==0.1.1
- accelerate
To accelerate the denoising process with TGATE, simply run `main.py`. For example:
- SD-2.1 w/ TGATE: generate an image with the caption: "A coral reef bustling with diverse marine life."

```bash
python main.py \
    --prompt 'A coral reef bustling with diverse marine life.' \
    --model 'sd_2.1' \
    --gate_step 10 \
    --saved_path './sd_2_1.png' \
    --inference_step 25
```
- SD-XL w/ TGATE: generate an image with the caption: "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

```bash
python main.py \
    --prompt 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k' \
    --model 'sd_xl' \
    --gate_step 10 \
    --saved_path './sd_xl.png' \
    --inference_step 25
```
- Pixart-Alpha w/ TGATE: generate an image with the caption: "An alpaca made of colorful building blocks, cyberpunk."

```bash
python main.py \
    --prompt 'An alpaca made of colorful building blocks, cyberpunk.' \
    --model 'pixart' \
    --gate_step 8 \
    --saved_path './pixart_alpha.png' \
    --inference_step 25
```
- LCM-SDXL w/ TGATE: generate an image with the caption: "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"

```bash
python main.py \
    --prompt 'Self-portrait oil painting, a beautiful cyborg with golden hair, 8k' \
    --model 'lcm_sdxl' \
    --gate_step 1 \
    --saved_path './lcm_sdxl.png' \
    --inference_step 4
```
- SDXL-DeepCache w/ TGATE: generate an image with the caption: "A haunted Victorian mansion under a full moon."

```bash
python main.py \
    --prompt 'A haunted Victorian mansion under a full moon.' \
    --model 'sd_xl' \
    --gate_step 10 \
    --saved_path './sd_xl_deepcache.png' \
    --inference_step 25 \
    --deepcache
```
- For LCMs, `gate_step` is set to 1 or 2, and `inference_step` is set to 4.
- To use DeepCache, set `deepcache` to True by passing the `--deepcache` flag.
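The invocations above differ only in their flag values, so they can be assembled programmatically. A minimal sketch: `build_tgate_cmd` is a hypothetical helper that uses only the flags shown in the examples, and its range check on `gate_step` is an assumption (0-indexed steps, as in the snippets in Step 1 and Step 2), not documented `main.py` behavior.

```python
import shlex
from typing import List


def build_tgate_cmd(prompt: str, model: str, gate_step: int, inference_step: int,
                    saved_path: str, deepcache: bool = False) -> List[str]:
    """Hypothetical helper: assemble a main.py call from the flags used above."""
    # Assumed sanity check: with 0-indexed steps, a gate outside the schedule
    # would never match cur_step, so nothing would ever be cached.
    if not 0 <= gate_step < inference_step:
        raise ValueError(f"gate_step={gate_step} must be in [0, {inference_step})")
    cmd = [
        "python", "main.py",
        "--prompt", prompt,
        "--model", model,
        "--gate_step", str(gate_step),
        "--saved_path", saved_path,
        "--inference_step", str(inference_step),
    ]
    if deepcache:
        cmd.append("--deepcache")  # boolean switch, as in the SDXL-DeepCache example
    return cmd


# LCM setting from the notes above: gate_step 1 (or 2) with 4 inference steps.
lcm_cmd = build_tgate_cmd(
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "lcm_sdxl", gate_step=1, inference_step=4, saved_path="./lcm_sdxl.png",
)
print(shlex.join(lcm_cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root. Building a list rather than a shell string avoids quoting problems with prompts that contain commas or quotes.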
- TGATE in ComfyUI: ComfyUI_TGate
We encourage users to read DeepCache and Adaptive Guidance.
Methods | U-Net | Transformer | Consistency Model |
---|---|---|---|
DeepCache | ✅ | ❌ | - |
Adaptive Guidance | ✅ | ✅ | ❌ |
TGATE (Ours) | ✅ | ✅ | ✅ |
Compared with DeepCache:
- TGATE caches once and reuses the cached feature until sampling ends.
- TGATE is friendlier to Transformer-based architectures and mobile devices, since it drops the high-resolution cross-attention.
- TGATE is complementary to DeepCache.
Compared with Adaptive Guidance:
- TGATE can reduce the parameters in the second stage.
- TGATE can further improve the inference efficiency.
- TGATE is complementary to non-CFG frameworks, e.g., latent consistency models.
- TGATE is open source.
- We thank Prompt-to-Prompt and diffusers for their great code.
If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.
```bibtex
@article{tgate,
  title={Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models},
  author={Zhang, Wentian and Liu, Haozhe and Xie, Jinheng and Faccio, Francesco and Shou, Mike Zheng and Schmidhuber, J{\"u}rgen},
  journal={arXiv preprint arXiv:2404.02747},
  year={2024}
}
```