
PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation

Home Page: https://pixart-alpha.github.io/PixArt-sigma-project/


👉 PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation


This repo contains PyTorch model definitions, pre-trained weights, and inference/sampling code for our paper exploring Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation. You can find more visualizations on our project page.

PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation
Junsong Chen*, Chongjian Ge*, Enze Xie*†, Yue Wu*, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, Zhenguo Li
Huawei Noah’s Ark Lab, DLUT, HKU, HKUST


Everyone is welcome to contribute 🔥🔥!!

Learning from the previous PixArt-α project, we will try to keep this repo as simple as possible so that everyone in the PixArt community can use it.


Breaking News 🔥🔥!!

  • (🔥 New) Apr. 6, 2024. 💥 PixArt-Σ 256px & 512px checkpoints are released!
  • (🔥 New) Mar. 29, 2024. 💥 PixArt-Σ training & inference code & toy data are released!!!

🔧 Dependencies and Installation

conda create -n pixart python==3.9.0
conda activate pixart
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

git clone https://github.com/PixArt-alpha/PixArt-sigma.git
cd PixArt-sigma
pip install -r requirements.txt
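
After installing, an optional sanity check like the one below (not required by any of the training or inference scripts) confirms that the PyTorch build can see your GPU before you move on:

import torch

# Optional environment sanity check; the PixArt scripts themselves do not need this.
print("torch version :", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device        :", torch.cuda.get_device_name(0))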

🔥 How to Train

1. PixArt Training

First of all: we started this new repo to build a more user-friendly and more compatible codebase. The main model structure is the same as PixArt-α, so you can still develop your features based on the original repo. Also, this repo will support PixArt-α in the future.

Now you can train your model without prior feature extraction. We reworked the data structure of the PixArt-α codebase so that everyone can start training, inference, and visualization from the very beginning without any pain.

1.1 Download the toy dataset

Download the toy dataset first. The dataset structure for training is:

cd ./pixart-sigma-toy-dataset

Dataset Structure
├──InternImgs/  (images are saved here)
│  ├──000000000000.png
│  ├──000000000001.png
│  ├──......
├──InternData/
│  ├──data_info.json    (meta data)
Optional(👇)
│  ├──img_sdxl_vae_features_1024resolution_ms_new    (precomputed SDXL-VAE image features, same name as images except .npy extension)
│  │  ├──000000000000.npy
│  │  ├──000000000001.npy
│  │  ├──......
│  ├──caption_features_new    (run tools/extract_caption_feature.py to generate caption T5 features, same name as images except .npz extension)
│  │  ├──000000000000.npz
│  │  ├──000000000001.npz
│  │  ├──......
│  ├──sharegpt4v_caption_features_new    (run tools/extract_caption_feature.py to generate ShareGPT4V caption T5 features, same name as images except .npz extension)
│  │  ├──000000000000.npz
│  │  ├──000000000001.npz
│  │  ├──......
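
To confirm that the download matches the layout above, a small inspection script such as the following can help. It assumes data_info.json is a list of per-image records (as the structure suggests) but makes no assumption about the exact field names; it simply reports what it finds:

import json
from pathlib import Path

root = Path("pixart-sigma-toy-dataset")  # adjust if the dataset lives elsewhere

images = sorted((root / "InternImgs").glob("*.png"))
meta = json.loads((root / "InternData" / "data_info.json").read_text())

print(f"{len(images)} images, {len(meta)} metadata entries")
if meta:
    print("fields in first entry:", sorted(meta[0].keys()))

# The feature directories are optional; report whatever has been precomputed.
for name in ("img_sdxl_vae_features_1024resolution_ms_new",
             "caption_features_new",
             "sharegpt4v_caption_features_new"):
    feature_dir = root / "InternData" / name
    if feature_dir.exists():
        print(name, "->", sum(1 for _ in feature_dir.iterdir()), "files")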

1.2 Download pretrained checkpoints

# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pixart_sigma_sdxlvae_T5_diffusers

# PixArt-Sigma checkpoints
python tools/download.py
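
Optionally, check that the PixArt-Σ weights landed where the training and demo commands below expect them. This is just a quick torch.load inspection; the internal layout of the checkpoint may vary between releases:

import torch
from pathlib import Path

# Path taken from the training/demo commands below; adjust for other resolutions.
ckpt_path = Path("output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth")
print(ckpt_path, "exists:", ckpt_path.exists())

ckpt = torch.load(ckpt_path, map_location="cpu")
# Depending on how the checkpoint was saved, the top level is either a plain
# state dict or a wrapper dict (e.g. with separate EMA weights).
if isinstance(ckpt, dict):
    print("top-level keys:", list(ckpt.keys())[:8])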

1.3 You are ready to train!

Select your desired config file from the config files directory.

python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 \
          train_scripts/train.py \
          configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py \
          --load-from output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth \
          --work-dir output/your_first_pixart-exp \
          --debug

💻 How to Test

1. Quick start with Gradio

To get started, first install the required dependencies. Make sure you've downloaded the checkpoint files from models (coming soon) to the output/pretrained_models folder, and then run on your local machine:

# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pixart_sigma_sdxlvae_T5_diffusers

# PixArt-Sigma checkpoints
python tools/download.py

# demo launch
python scripts/interface.py --model_path output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth --image_size 512 --port 11223
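
Once the demo reports that it is running, it should be reachable on the port passed via --port (11223 above). A minimal reachability check from another terminal, assuming the default local binding:

import requests

# 11223 matches the --port flag in the launch command above.
resp = requests.get("http://localhost:11223", timeout=10)
print("demo is up, HTTP status:", resp.status_code)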

2. Integration in diffusers

(Coming soon)

💪 To-Do List

We will try our best to release the following before April 10th:

  • Training code
  • Inference code
  • Model zoo
  • Diffusers
  • Training & inference code of One Step Sampling with DMD


About

PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation

https://pixart-alpha.github.io/PixArt-sigma-project/

License: GNU Affero General Public License v3.0


Languages

Python 75.5% · Jupyter Notebook 24.4% · Dockerfile 0.1% · CSS 0.0%