diffusion-hyperfeatures / diffusion_hyperfeatures

Official PyTorch Implementation for Diffusion Hyperfeatures, NeurIPS 2023

Diffusion Hyperfeatures: Searching Through Time and Space for Semantic Correspondence

This repository contains the code accompanying the paper Diffusion Hyperfeatures: Searching Through Time and Space for Semantic Correspondence. The code implements Diffusion Hyperfeatures, a framework for consolidating multi-scale and multi-timestep feature maps from a diffusion model into per-pixel feature descriptors.
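To make the idea of consolidating feature maps concrete, here is a minimal NumPy sketch, not the repository's implementation. It assumes every feature map already has the same number of channels, resizes each one to a common resolution with nearest-neighbor sampling, and combines them with softmax-normalized mixing weights. The names upsample_nearest, aggregate, and mixing_logits are hypothetical, and the weights here are fixed by hand rather than learned by an aggregation network.

```python
import numpy as np

def upsample_nearest(fmap, size):
    """Nearest-neighbor resize of a (C, H, W) feature map to (C, size, size)."""
    _, h, w = fmap.shape
    rows = np.arange(size) * h // size  # source row for each output row
    cols = np.arange(size) * w // size  # source column for each output column
    return fmap[:, rows][:, :, cols]

def aggregate(feature_maps, mixing_logits, size):
    """Weighted sum of feature maps, one map per (layer, timestep) pair.

    feature_maps: list of (C, H_i, W_i) arrays with a shared channel count C.
    mixing_logits: one scalar per map; softmax turns them into weights.
    """
    logits = np.asarray(mixing_logits, dtype=float)
    weights = np.exp(logits - logits.max())
    weights = weights / weights.sum()
    out = np.zeros((feature_maps[0].shape[0], size, size))
    for weight, fmap in zip(weights, feature_maps):
        out += weight * upsample_nearest(fmap, size)
    return out

# Two toy maps at different resolutions, mixed with equal weights.
maps = [np.ones((3, 2, 2)), 3.0 * np.ones((3, 4, 4))]
descriptors = aggregate(maps, mixing_logits=[0.0, 0.0], size=4)
print(descriptors.shape)
```

Each spatial location of the result holds a C-dimensional vector, which is the per-pixel descriptor used for matching.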

Releases

  • 🚀 2024/02/17 - Added generic training code showing how to train Diffusion Hyperfeatures for tasks beyond semantic correspondence.
  • 🚀 2023/09/28 - Added training code for Diffusion Hyperfeatures.
  • 🚀 2023/07/07 - Added extraction code and demos for Diffusion Hyperfeatures on real and synthetic images.

Setup

This code was tested with Python 3.8. To install the necessary packages, please run:

conda env create -f environment.yml
conda activate dhf

Pretrained Networks

You can download the pretrained aggregation networks by running download_weights.sh.

Extraction

To extract and save Diffusion Hyperfeatures for your own set of real images, or a set of synthetic images with your own custom prompts, run extract_hyperfeatures.py.

To run on real images, you can provide a folder of images with or without corresponding annotations.

python3 extract_hyperfeatures.py --save_root hyperfeatures --config_path configs/real.yaml --image_root assets/spair/images --images_or_prompts_path annotations/spair_71k_test-6.json 

python3 extract_hyperfeatures.py --save_root hyperfeatures --config_path configs/real.yaml --image_root assets/spair/images --images_or_prompts_path ""

To run on synthetic images, you can provide a json file containing a list of prompts.

python3 extract_hyperfeatures.py --save_root hyperfeatures --config_path configs/synthetic.yaml  --image_root "" --images_or_prompts_path annotations/synthetic-3.json
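If you want to script several extraction runs, the commands above can be assembled programmatically. The sketch below is only an illustration: build_extract_command is a hypothetical helper, and it uses only the flags and paths shown in the commands above. It builds the argument list without running anything.

```python
import shlex

def build_extract_command(config_path, save_root="hyperfeatures",
                          image_root="", images_or_prompts_path=""):
    """Assemble the argument list for extract_hyperfeatures.py.

    Empty strings are passed through unchanged, matching the commands above,
    where an unused option is given as "".
    """
    return [
        "python3", "extract_hyperfeatures.py",
        "--save_root", save_root,
        "--config_path", config_path,
        "--image_root", image_root,
        "--images_or_prompts_path", images_or_prompts_path,
    ]

# Real images with annotations.
real_cmd = build_extract_command(
    "configs/real.yaml",
    image_root="assets/spair/images",
    images_or_prompts_path="annotations/spair_71k_test-6.json",
)

# Synthetic images from a JSON file of prompts.
synthetic_cmd = build_extract_command(
    "configs/synthetic.yaml",
    images_or_prompts_path="annotations/synthetic-3.json",
)

# Print shell-quoted versions for inspection.
print(shlex.join(real_cmd))
print(shlex.join(synthetic_cmd))
```

To actually launch a run, the list can be handed to subprocess.run(real_cmd, check=True) from the repository root.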

Training

To train an aggregation network on your own dataset of images with labeled correspondences, use the training script below. We recommend training a one-step model first to check that everything works: it converges much faster than the multi-step model, though to worse final performance. The save_timestep and num_timesteps fields in the config file control this. The model_id field selects the base model (SDv1-5 or SDv2-1). To use your own weights in the Jupyter notebooks instead of our pretrained ones, change the weights_path field.

python3 train_hyperfeatures.py --config_path configs/train.yaml

Logging uses wandb, so make sure you are logged in before training.

wandb login

Make sure you also download the SPair-71k dataset if you want to use our default config.

mkdir datasets
wget -P datasets http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz
tar -xvf datasets/SPair-71k.tar.gz -C datasets

Semantic Keypoint Matching

We also provide demos for the semantic keypoint matching task using Diffusion Hyperfeatures.

For real images, real_demo walks through visualizing correspondences using either nearest neighbors or mutual nearest neighbors.
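The difference between the two matching strategies can be illustrated with a small dependency-free sketch. This is not the notebook's code: the function names are hypothetical, descriptors are plain Python lists, and cosine similarity is an assumed choice of similarity measure.

```python
import math

def cosine(u, v):
    """Cosine similarity between two descriptor vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def nearest_neighbors(src, tgt):
    """For each source descriptor, the index of its most similar target."""
    return [max(range(len(tgt)), key=lambda j: cosine(s, tgt[j])) for s in src]

def mutual_nearest_neighbors(src, tgt):
    """Keep a pair (i, j) only if i and j select each other."""
    forward = nearest_neighbors(src, tgt)
    backward = nearest_neighbors(tgt, src)
    return [(i, j) for i, j in enumerate(forward) if backward[j] == i]

# Three source descriptors, two target descriptors (toy 2-D values).
src = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
tgt = [[0.9, 0.1], [0.1, 0.9]]

print(nearest_neighbors(src, tgt))         # [0, 1, 1]
print(mutual_nearest_neighbors(src, tgt))  # [(0, 0), (1, 1)]
```

Plain nearest neighbors assigns every source point a match, so the third source point is paired with target 1 even though target 1 prefers source 1. The mutual check drops that one-sided pair, trading coverage for fewer spurious matches.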

For synthetic images, synthetic_demo provides an interactive demo for visualizing correspondences given different prompts and different sets of user annotated source points.

Citing

@inproceedings{luo2023dhf,
  title={Diffusion Hyperfeatures: Searching Through Time and Space for Semantic Correspondence},
  author={Luo, Grace and Dunlap, Lisa and Park, Dong Huk and Holynski, Aleksander and Darrell, Trevor},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023}
}

Acknowledgements

Our codebase builds on top of a few prior works, including Deep ViT Features as Dense Visual Descriptors, Zero-Shot Category-Level Object Pose Estimation, Shape-Guided Diffusion, and ODISE.
