jinzhuwei / DiM

Diffusion Image Mamba

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Install

Install Mamba-ssm

Refer to Vision Mamba

# Install causal_conv1d and mamba
pip install -e causal_conv1d>=1.1.0
pip install -e mamba-1p1p1

Train

# Assign HF_HOME to the cache directory
export HF_HOME="/comp_robot/rentianhe/caohe/cache"

# Train with 4 GPUs
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train_mamba.py --model DiM-S/2 --feature-path /shared_space/caohe/DATA/imagenet1k/train_vae --lr 5e-4
  • feature-path: the path to the pre-extracted features (use SD-VAE compression rate as 8)
  • model: the model name (Now supported DiM-[S|M|L|XL]/[2|4|8])
  • [default] batch size : global batch size = 256 (constant for all models, suggested in DiT)
  • [optional] --lr : learning rate (default 1e-4, for mamba, 5e-4 is suggested)
  • [optional] --ckpt-every: default to 50_000 (= 10 epochs)

The ckpt will be saved under results/xxx-MODEL_VERSION/checkpoints/

Evaluation

Sampling scripts

# Assign HF_HOME to the cache directory
export HF_HOME="/comp_robot/rentianhe/caohe/cache"
# 
MODEL_VERSION=005-DiM-S-2 # The model version
CKPT=0100000 # The checkpoint number
torchrun --nnodes=1 --nproc_per_node=4 sample_ddp.py --model DiM-S/2 --num-fid-samples 50000 --ckpt results/$MODEL_VERSION/checkpoints/$CKPT.pt --sample-dir samples/$MODEL_VERSION-$CKPT --per-proc-batch-size 64

FID evaluation

First need to install related packages following ADM's TensorFlow evaluation suite

IN_50K_REFERENCE=/shared_space/caohe/DATA/imagenet1k/VIRTUAL_imagenet256_labeled.npz
# ! adapt the path to the sample features path
python evaluations/evaluator.py $IN_50K_REFERENCE samples/DiM-B-2-ckpt-0100000/DiM-B-2-0100000-size-256-vae-ema-cfg-1.5-seed-0.npz

About

Diffusion Image Mamba

License:Other


Languages

Language:Python 72.4%Language:Cuda 18.2%Language:C++ 7.4%Language:Jupyter Notebook 1.2%Language:C 0.7%