xuesong39 / DAC

[CVPR 2024] Official implementation of CVPR 2024 paper: "Doubly Abductive Counterfactual Inference for Text-based Image Editing"

Doubly Abductive Counterfactual Inference for Text-based Image Editing

This repository contains the code for the CVPR 2024 paper Doubly Abductive Counterfactual Inference for Text-based Image Editing.

Setup

Dependency Installation

First, clone the repository:

git clone https://github.com/xuesong39/DAC

Then, install the dependencies in a new virtual environment:

cd DAC
git clone https://github.com/huggingface/diffusers -b v0.24.0
cd diffusers
pip install -e .

Finally, return to the main DAC folder (cd ..) and run:

pip install -r requirements.txt

Data Preparation

The images and annotations we use in the paper can be found here. For the data format used in the experiments, we provide some examples in the folder DAC/data. For example, for the image DAC/data/cat/train/cat.jpeg, the folder containing the source prompt is DAC/data/cat/, while the folder containing the target prompt is DAC/data/cat-cap/.
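The scripts below read captions through --caption_column="text". A minimal sketch of writing such a caption file follows. It assumes the data folders follow the Hugging Face imagefolder convention (a metadata.jsonl next to the images, with file_name and text fields), which the source does not state explicitly; check the examples in DAC/data before relying on it. The helper name write_metadata is hypothetical.

```python
import json
from pathlib import Path


def write_metadata(train_dir: str, captions: dict[str, str]) -> Path:
    """Write metadata.jsonl pairing each image file with its prompt.

    Assumption: the training scripts load the folder with the
    "imagefolder" convention, where each line of metadata.jsonl holds
    a "file_name" and a caption field.
    """
    folder = Path(train_dir)
    folder.mkdir(parents=True, exist_ok=True)
    out = folder / "metadata.jsonl"
    with out.open("w", encoding="utf-8") as f:
        for file_name, text in captions.items():
            # "text" matches the --caption_column="text" flag.
            f.write(json.dumps({"file_name": file_name, "text": text}) + "\n")
    return out
```

Under this assumption, calling write_metadata("data/cat/train", {"cat.jpeg": "A cat."}) would describe the source prompt, and the same call on data/cat-cap/train with "A cat wearing a wool cap." would describe the target prompt.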

Usage

Abduction-1

The fine-tuning script for abduction on U is train_text_to_image_lora.sh as follows:

export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
export TRAIN_DIR="ORIGIN_DATA_PATH"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR --caption_column="text" \
  --resolution=512 \
  --train_batch_size=1 \
  --num_train_epochs=1000 --checkpointing_steps=1000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --rank=512 \
  --output_dir="U_PATH" \
  --validation_prompt="xxx" \
  --report_to="wandb" \
  --validation_epochs=500

Please specify TRAIN_DIR (e.g., "./data/cat/"), --output_dir (e.g., "ckpt/cat"), and --validation_prompt (e.g., "A cat.").
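For scripted runs, the same launch can be assembled from Python. The sketch below is only an illustration: the helper names abduction1_command and run_abduction1 are hypothetical and not part of the repository, while the flags and values are copied from the shell script above.

```python
import os
import subprocess


def abduction1_command(train_dir: str, output_dir: str, validation_prompt: str) -> list[str]:
    """Build the argv list for the Abduction-1 (U) fine-tuning launch."""
    return [
        "accelerate", "launch", "train_text_to_image_lora.py",
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1-base",
        f"--train_data_dir={train_dir}",
        "--caption_column=text",
        "--resolution=512",
        "--train_batch_size=1",
        "--num_train_epochs=1000",
        "--checkpointing_steps=1000",
        "--learning_rate=1e-04",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--seed=42",
        "--rank=512",
        f"--output_dir={output_dir}",
        f"--validation_prompt={validation_prompt}",
        "--report_to=wandb",
        "--validation_epochs=500",
    ]


def run_abduction1(train_dir: str, output_dir: str, validation_prompt: str, gpu: str = "0") -> None:
    """Run the launch on one GPU; raises CalledProcessError on failure."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    cmd = abduction1_command(train_dir, output_dir, validation_prompt)
    subprocess.run(cmd, env=env, check=True)
```

Passing the arguments as a list avoids shell quoting problems when the validation prompt contains spaces, for example run_abduction1("./data/cat/", "ckpt/cat", "A cat.").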

Abduction-2

The fine-tuning script for abduction on Δ is train_text_to_image_lora_t.sh as follows:

export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
export TRAIN_DIR="TARGET_DATA_PATH"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_text_to_image_lora_t.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --unet_lora_path="U_PATH" \
  --train_data_dir=$TRAIN_DIR --caption_column="text" \
  --resolution=512 --train_text_encoder \
  --train_batch_size=1 \
  --num_train_epochs=1000 --checkpointing_steps=1000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --annealing=0.8 \
  --output_dir="DELTA_PATH" \
  --report_to="wandb" \
  --validation_epochs=500

Please specify TRAIN_DIR (e.g., "./data/cat-cap/"), --unet_lora_path (e.g., "ckpt/cat"), and --output_dir (e.g., "ckpt/cat-cap-annealing0.8"). You can also change --annealing to control the hyperparameter $\eta$.
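When trying several values of --annealing, it helps to keep the value in the output directory name, as the example "ckpt/cat-cap-annealing0.8" does. The sketch below derives such names for a sweep. The helper names and any value other than 0.8 are hypothetical; the source does not state which range of $\eta$ is meaningful.

```python
def delta_output_dir(name: str, annealing: float, root: str = "ckpt") -> str:
    """Name a Delta checkpoint directory after its annealing value."""
    return f"{root}/{name}-annealing{annealing}"


def annealing_sweep(name: str, values: list[float]) -> list[tuple[float, str]]:
    """Pair each annealing value with the output directory it would use."""
    return [(eta, delta_output_dir(name, eta)) for eta in values]


if __name__ == "__main__":
    # Illustrative values only; 0.8 is the one used in the scripts above.
    for eta, out_dir in annealing_sweep("cat-cap", [0.6, 0.8]):
        print(f"--annealing={eta} --output_dir={out_dir}")
```

Each pair can then be substituted into the Abduction-2 script, and the same value passed to --annealing at inference time.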

Action & Prediction

The inference script is inference_t.sh as follows:

CUDA_VISIBLE_DEVICES=0 python inference_t.py \
 --annealing=0.8 \
 --unet_path="U_PATH" \
 --text_path="DELTA_PATH" \
 --target_prompt="xxx" \
 --save_path="./"

Please specify --unet_path (e.g., "ckpt/cat"), --text_path (e.g., "ckpt/cat-cap-annealing0.8"), and --target_prompt (e.g., "A cat wearing a wool cap.").
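Inference depends on the outputs of both fine-tuning stages, so a quick pre-flight check can catch a mistyped path before the model is loaded. This is a small sketch under the assumption that each checkpoint path is a directory (both are written through --output_dir); the helper name missing_checkpoints is hypothetical.

```python
from pathlib import Path


def missing_checkpoints(paths: dict[str, str]) -> list[str]:
    """Return the flags whose checkpoint directory does not exist."""
    return [flag for flag, path in paths.items() if not Path(path).is_dir()]


if __name__ == "__main__":
    missing = missing_checkpoints({
        "--unet_path": "ckpt/cat",                    # output of Abduction-1
        "--text_path": "ckpt/cat-cap-annealing0.8",   # output of Abduction-2
    })
    if missing:
        print("Missing checkpoint directories for:", ", ".join(missing))
```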

Optional Usage

This part contains the implementation of the variant discussed in the ablation analysis section of the paper, i.e., the ablation on Abduction-1. We can incorporate another exogenous variable T into Abduction-1 to further improve fidelity.

Abduction-1

The fine-tuning script for abduction on U is the same as above.

The fine-tuning script for abduction on T is train_text_to_image_lora_t.sh as follows:

export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
export TRAIN_DIR="ORIGIN_DATA_PATH"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_text_to_image_lora_t.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --unet_lora_path="U_PATH" \
  --train_data_dir=$TRAIN_DIR --caption_column="text" \
  --resolution=512 --train_text_encoder \
  --train_batch_size=1 \
  --num_train_epochs=1000 --checkpointing_steps=1000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --annealing=0.8 \
  --output_dir="T_PATH" \
  --report_to="wandb" \
  --validation_epochs=500

Please specify TRAIN_DIR (e.g., "./data/cat/"), --unet_lora_path (e.g., "ckpt/cat"), and --output_dir (e.g., "ckpt/cat-annealing0.8").

Abduction-2

The fine-tuning script for abduction on Δ is train_text_to_image_lora_t2.sh as follows:

export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
export TRAIN_DIR="TARGET_DATA_PATH"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_text_to_image_lora_t2.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --unet_lora_path="U_PATH" \
  --text_lora1_path="T_PATH" \
  --train_data_dir=$TRAIN_DIR --caption_column="text" \
  --resolution=512 --train_text_encoder \
  --train_batch_size=1 \
  --num_train_epochs=1000 --checkpointing_steps=1000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --annealing=0.8 \
  --output_dir="DELTA_PATH" \
  --report_to="wandb" \
  --validation_epochs=500

Please specify TRAIN_DIR (e.g., "./data/cat-cap/"), --unet_lora_path (e.g., "ckpt/cat"), --text_lora1_path (e.g., "ckpt/cat-annealing0.8"), and --output_dir (e.g., "ckpt/cat-cap-annealing0.8-t2").

Action & Prediction

The inference script is inference_t2.sh as follows:

CUDA_VISIBLE_DEVICES=0 python inference_t2.py \
 --annealing=0.8 \
 --unet_path="U_PATH" \
 --text1_path="T_PATH" \
 --text2_path="DELTA_PATH" \
 --target_prompt="xxx" \
 --save_path="./"

Please specify --unet_path (e.g., "ckpt/cat"), --text1_path (e.g., "ckpt/cat-annealing0.8"), --text2_path (e.g., "ckpt/cat-cap-annealing0.8-t2"), and --target_prompt (e.g., "A cat wearing a wool cap.").
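The optional variant involves three checkpoints, and it is easy to pass one to the wrong flag. The sketch below derives the three paths from the naming used in the examples above and maps them to the inference_t2.py flags. The OptionalRun class and its field names are hypothetical; the directory naming simply mirrors the examples and is not enforced by the scripts.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OptionalRun:
    source: str              # folder name of the source data, e.g. "cat"
    target: str              # folder name of the target data, e.g. "cat-cap"
    annealing: float = 0.8
    root: str = "ckpt"

    @property
    def u_path(self) -> str:
        # Abduction-1 on U
        return f"{self.root}/{self.source}"

    @property
    def t_path(self) -> str:
        # Abduction-1 on T
        return f"{self.root}/{self.source}-annealing{self.annealing}"

    @property
    def delta_path(self) -> str:
        # Abduction-2 on Delta
        return f"{self.root}/{self.target}-annealing{self.annealing}-t2"

    def inference_args(self, target_prompt: str, save_path: str = "./") -> list[str]:
        """Build the argv list for inference_t2.py."""
        return [
            "python", "inference_t2.py",
            f"--annealing={self.annealing}",
            f"--unet_path={self.u_path}",
            f"--text1_path={self.t_path}",
            f"--text2_path={self.delta_path}",
            f"--target_prompt={target_prompt}",
            f"--save_path={save_path}",
        ]
```

For instance, OptionalRun("cat", "cat-cap").inference_args("A cat wearing a wool cap.") yields the argument list matching the example paths given above.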

Checkpoints

We provide some checkpoints below:

| Image | Abduction-1 | Abduction-2 |
| --- | --- | --- |
| DAC/data/cat | U | Δ |
| DAC/data/glass | U | Δ |
| DAC/data/black | U | Δ |
| DAC/data/cat | U, T | Δ |
| DAC/data/glass | U, T | Δ |
| DAC/data/black | U, T | Δ |

Acknowledgments

This code builds on the following codebases: Diffusers and PEFT. Many thanks to their authors!
