gohyojun15 / ANT_diffusion

[NeurIPS 2023] Official PyTorch implementation of "Addressing Negative Transfer in Diffusion Models"

Home Page: https://gohyojun15.github.io/ANT_diffusion/

[NeurIPS 2023] Addressing Negative Transfer in Diffusion Models

Figure: Image generated by ANT-UW (DiT-L) with guidance scale 3.0.

This repository contains the official PyTorch implementation of the paper "Addressing Negative Transfer in Diffusion Models" (NeurIPS 2023). For a fuller explanation of the method, please see our project page and the paper on arXiv.

This code trains DiT models with ANT on the ImageNet dataset. Our implementation is based on DiT, LibMTL, and NashMTL.

Install pre-requisites

$ pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
$ pip install -r requirements.txt
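
After installing, it can be useful to confirm that the pinned versions are the ones actually in the environment. The following is a minimal sketch using only the standard library; the helper names `check_pins` and `matches` are hypothetical and not part of this repository. It assumes that CUDA wheels may report a local version suffix such as `+cu118`, which is ignored in the comparison.

```python
from importlib import metadata

# Pins copied from the pip command above.
PINNED = {"torch": "2.0.1", "torchvision": "0.15.2", "torchaudio": "2.0.2"}


def check_pins(pins=PINNED):
    """Return {package: (expected, installed)}; installed is None if missing."""
    report = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (expected, installed)
    return report


def matches(expected, installed):
    """Compare versions, ignoring a local suffix such as '+cu118'."""
    return installed is not None and installed.split("+")[0] == expected


if __name__ == "__main__":
    for name, (expected, installed) in check_pins().items():
        status = "ok" if matches(expected, installed) else "MISMATCH"
        print(f"{name}: expected {expected}, found {installed} [{status}]")
```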

Training DiT with ANT.

  • Train a DiT-S model with ANT on a single GPU:

    python train_ant_single_gpu.py \
    --data_path [Your Data Path] \
    --total_clusters [Your cluster size (Default: 5)] \
    --mtl_method [uw, nash, or pcgrad]
    
  • Train a DiT-L model with ANT-UW on multiple GPUs:

    torchrun --nproc_per_node=[number of gpus] train_uw_multi_gpu.py \
    --data_path [Your Data Path] \
    --total_clusters [Your cluster size (Default: 5)]
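
As a rough, framework-free illustration of what `--total_clusters` and `--mtl_method uw` control conceptually, the sketch below groups denoising timesteps into clusters, treats each cluster as one task, and combines the per-cluster losses with uncertainty weighting (UW). This is not the repository's training loop. The assumptions are: 1000 diffusion timesteps, equal-width intervals (the repository may choose intervals differently), hypothetical helper names, and UW written in the common form exp(-s_i) * L_i + s_i, where s_i is a log-variance that would be a learnable parameter in real training.

```python
import math


def timestep_to_cluster(t, num_timesteps=1000, total_clusters=5):
    """Map timestep t in [0, num_timesteps) to an equal-width interval index."""
    if not 0 <= t < num_timesteps:
        raise ValueError("timestep out of range")
    return t * total_clusters // num_timesteps


def cluster_losses(timesteps, losses, num_timesteps=1000, total_clusters=5):
    """Average per-sample losses within each cluster; None for empty clusters."""
    sums = [0.0] * total_clusters
    counts = [0] * total_clusters
    for t, loss in zip(timesteps, losses):
        c = timestep_to_cluster(t, num_timesteps, total_clusters)
        sums[c] += loss
        counts[c] += 1
    return [s / n if n else None for s, n in zip(sums, counts)]


def uncertainty_weighted_loss(task_losses, log_vars):
    """Combine task losses as sum_i exp(-s_i) * L_i + s_i, skipping empty tasks."""
    total = 0.0
    for loss, s in zip(task_losses, log_vars):
        if loss is None:
            continue  # no sample from this cluster in the batch
        total += math.exp(-s) * loss + s
    return total


if __name__ == "__main__":
    per_cluster = cluster_losses([0, 100, 999], [1.0, 3.0, 5.0])
    print(per_cluster)
    print(uncertainty_weighted_loss(per_cluster, [0.0] * 5))
```

With all log-variances at zero the combined loss reduces to a plain sum of the per-cluster losses; raising s_i for a cluster down-weights its loss while the added s_i term penalizes doing so indefinitely.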
    

Sampling DiT trained with ANT.

torchrun --nproc_per_node=[number of gpus] sample_ddp.py \
--model_config [Your config path (config/DiT-L.yaml, config/DiT-S.yaml)] \
--ckpt [Your ckpt path] \
--sample-dir [Your sample dir] 

Interval clustering

Example code for interval clustering is provided in interval_clustering.py.
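
For readers who want the general idea before opening that file, here is a minimal sketch under stated assumptions; it is not a copy of interval_clustering.py. It assumes that interval clustering means splitting an ordered sequence of per-timestep values into k contiguous intervals that minimize a within-interval cost, solved by dynamic programming. The function name, the squared-deviation cost, and the `values` input are illustrative choices only.

```python
def interval_clustering(values, k):
    """Split `values` (order preserved) into k contiguous intervals that
    minimize the total squared deviation from each interval's mean.
    Returns a list of (start, end) index pairs with `end` exclusive."""
    n = len(values)
    if not 1 <= k <= n:
        raise ValueError("k must satisfy 1 <= k <= len(values)")

    # Prefix sums let us evaluate any interval cost in O(1).
    ps = [0.0] * (n + 1)
    ps2 = [0.0] * (n + 1)
    for i, v in enumerate(values):
        ps[i + 1] = ps[i] + v
        ps2[i + 1] = ps2[i] + v * v

    def cost(i, j):
        # Sum of squared deviations of values[i:j] from their mean.
        s = ps[j] - ps[i]
        return (ps2[j] - ps2[i]) - s * s / (j - i)

    inf = float("inf")
    # best[c][j]: minimal cost of splitting values[:j] into c intervals.
    best = [[inf] * (n + 1) for _ in range(k + 1)]
    cut = [[0] * (n + 1) for _ in range(k + 1)]
    best[0][0] = 0.0
    for c in range(1, k + 1):
        for j in range(c, n + 1):
            for i in range(c - 1, j):
                cand = best[c - 1][i] + cost(i, j)
                if cand < best[c][j]:
                    best[c][j] = cand
                    cut[c][j] = i

    # Walk the stored cut points backwards to recover the intervals.
    bounds = []
    j = n
    for c in range(k, 0, -1):
        i = cut[c][j]
        bounds.append((i, j))
        j = i
    return bounds[::-1]


if __name__ == "__main__":
    print(interval_clustering([0.0, 0.1, 0.2, 5.0, 5.1, 9.9, 10.0], 3))
```

The triple loop runs in O(k n^2) time, which is cheap for the number of timesteps typical of diffusion models. A different cost only requires swapping out `cost`.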

Model card ($k$=5)

All models can be downloaded from the OneDrive link.

| Model | FID | IS | Precision | Recall |
|---|---|---|---|---|
| DiT-L + ANT-UW (Multi-GPU) | 5.695 | 186.661 | 0.811 | 0.491 |
| DiT-S + Nash | 44.65 | 33.48 | 0.4209 | 0.5272 |
| DiT-S + UW | 48.40 | 30.84 | 0.4196 | 0.5172 |

License: MIT License


Languages

Language:Python 100.0%