haoyuli02 / diffusion-ReTrac

[TMLR 2024] "Data Attribution for Diffusion Models: Timestep-induced Bias in Influence Estimation"

Data Attribution for Diffusion Models: Timestep-induced Bias in Influence Estimation

This is the official implementation of "Data Attribution for Diffusion Models: Timestep-induced Bias in Influence Estimation" (TMLR 2024).

Tong Xie*, Haoyu Li*, Andrew Bai, Cho-Jui Hsieh

[Paper] [OpenReview]


Table of Contents
  1. TL;DR
  2. Setup
  3. Usage

TL;DR

Influence estimates for diffusion models can depend heavily on the timesteps sampled during training, which introduces bias and arbitrariness into attribution results. We identify a dominating norm effect, in which this bias causes the same training samples to rank as most influential across diverse test images (i.e., they are generally influential rather than specific to the test image). To address this, we present Diffusion-ReTrac, which uses a re-normalization technique to provide fair and targeted attribution.
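
The dominating norm effect can be illustrated with a toy sketch. It assumes a TracIn-style score (a sum over checkpoints of learning rate times the dot product of test and training loss gradients) and assumes that re-normalization means dividing each training gradient by its norm. The function names, the toy gradients, and the learning rate are hypothetical and are not taken from diffusion_tracin.py; the paper and the repository remain the reference for the exact formulation.

```python
import math


def dot(u, v):
    """Dot product of two equal-length gradient vectors."""
    return sum(a * b for a, b in zip(u, v))


def norm(u):
    """Euclidean norm of a gradient vector."""
    return math.sqrt(dot(u, u))


def tracin_score(test_grads, train_grads, lrs):
    """TracIn-style score: sum over checkpoints of lr * <g_test, g_train>."""
    return sum(lr * dot(gt, gz) for gt, gz, lr in zip(test_grads, train_grads, lrs))


def retrac_score(test_grads, train_grads, lrs, eps=1e-12):
    """Same sum, but each training gradient is rescaled to unit norm (assumed form)."""
    return sum(
        lr * dot(gt, gz) / (norm(gz) + eps)
        for gt, gz, lr in zip(test_grads, train_grads, lrs)
    )


def top_influential(score_fn, test_grads, train_set, lrs):
    """Name of the training sample with the highest score for one test sample."""
    return max(train_set, key=lambda name: score_fn(test_grads, train_set[name], lrs))


# Toy data: one checkpoint, hypothetical 2-D gradients.
lrs = [0.1]
train_set = {
    "A": [[10.0, 10.0]],  # large gradient norm, e.g. from an unlucky timestep
    "B": [[1.0, 0.0]],
    "C": [[0.0, 1.0]],
}
test_x = [[1.0, 0.1]]  # closest in direction to B
test_y = [[0.1, 1.0]]  # closest in direction to C

for name, test in (("x", test_x), ("y", test_y)):
    print(
        name,
        "TracIn top:", top_influential(tracin_score, test, train_set, lrs),
        "| ReTrac top:", top_influential(retrac_score, test, train_set, lrs),
    )
```

In this toy setup the large-norm sample A tops the un-normalized ranking for both test samples, while the re-normalized score ranks B first for `test_x` and C first for `test_y`.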

Setup

pip install -r requirements.txt

Usage

1. Model Training

The train.py file provides code to train a diffusion model on custom datasets. It also saves the training timesteps and noise into the checkpoints so that they can be used for attribution.

For example, to train on the CIFAR-MNIST dataset, run the following:

python3 train.py --gpu=0 --dataset='cifar_mnist' --learning_rate=0.0001 --num_epochs=500 --save_model_epoch=50 --train_batch_size=32 --resolution=32 --output_dir='trained_models/cifar_mnist' --samples_dir='trained_outputs/cifar_mnist' --loss_logs_dir="training_logs/cifar_mnist"
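
The bookkeeping described above, recording which timestep and noise each training sample received, can be pictured with the following standalone sketch. It is not the code in train.py: the function name `log_training_draws`, the record fields, and the idea of storing a noise seed instead of the noise tensor are all assumptions made for illustration.

```python
import random


def log_training_draws(num_samples, num_timesteps, seed):
    """Draw one timestep and one noise seed per training sample and record them.

    Hypothetical helper: a real training loop would draw these while
    computing the diffusion loss, then store them alongside the weights.
    """
    rng = random.Random(seed)
    records = []
    for idx in range(num_samples):
        records.append(
            {
                "sample": idx,
                "timestep": rng.randrange(num_timesteps),  # t in [0, num_timesteps)
                "noise_seed": rng.randrange(2**32),  # enough to regenerate the noise
            }
        )
    return records


# A checkpoint-like dictionary; "model_state" is a placeholder for the weights.
checkpoint = {
    "epoch": 50,
    "model_state": None,
    "train_records": log_training_draws(num_samples=4, num_timesteps=1000, seed=0),
}

for record in checkpoint["train_records"]:
    print(record)
```

Keeping these draws with each checkpoint means an attribution pass can evaluate every training sample's gradient at the timestep it actually saw during training, instead of resampling a new one.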

2. Diffusion-TracIn / ReTrac

The main.py file provides code to run Diffusion-TracIn / ReTrac; the --retrac parameter controls whether ReTrac or TracIn is performed. The implementations are located in diffusion_tracin.py. For example, to run ReTrac on Tiny ImageNet, run the following:

python3 main.py --dataset='zh-plus/tiny-imagenet' --gpu=2 --ckpt_dir='trained_models/tiny_imagenet' --task='train' --retrac --interval=20 --save_path='influence/tiny_imagenet/retrac'

3. Generation

The generate.py file provides code to generate images from trained model checkpoints. Example usage:

python3 generate.py --gpu=0 --samples_dir="test_samples/gen" --resolution=128 --pretrained_model_path="path_to_ckpt" --eval_batch_size=32

Citation

If you find this project useful, please consider citing our paper:

@misc{xie2024dataattributiondiffusionmodels,
      title={Data Attribution for Diffusion Models: Timestep-induced Bias in Influence Estimation}, 
      author={Tong Xie and Haoyu Li and Andrew Bai and Cho-Jui Hsieh},
      year={2024},
      eprint={2401.09031},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2401.09031}, 
}
