
Official PyTorch implementation of PDAE (NeurIPS 2022)


Unsupervised Representation Learning from Pre-trained Diffusion Probabilistic Models (PDAE)

This repository is the official PyTorch implementation of PDAE (NeurIPS 2022).

@inproceedings{zhang2022unsupervised,
  title={Unsupervised Representation Learning from Pre-trained Diffusion Probabilistic Models},
  author={Zhang, Zijian and Zhao, Zhou and Lin, Zhijie},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}

Dataset

We use the ready-to-use LMDB datasets provided by Diff-AE (https://github.com/phizaz/diffae#lmdb-datasets).

The directory structure should be:

data
├─horse
│  ├─data.mdb
│  └─lock.mdb
├─ffhq
│  ├─data.mdb
│  └─lock.mdb
├─celebahq
│  ├─CelebAMask-HQ-attribute-anno.txt
│  ├─data.mdb
│  └─lock.mdb
├─celeba64
│  ├─data.mdb
│  └─lock.mdb
└─bedroom
   ├─data.mdb
   └─lock.mdb
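Before training, it can help to confirm that the LMDB folders match the directory tree above. The snippet below is a minimal sketch and is not part of the repository. The helper name `find_missing_files` is hypothetical, and the script assumes the `data` folder sits in the current working directory.

```python
from pathlib import Path

# Expected files per dataset, copied from the directory tree above.
EXPECTED_LAYOUT = {
    "horse": ["data.mdb", "lock.mdb"],
    "ffhq": ["data.mdb", "lock.mdb"],
    "celebahq": ["CelebAMask-HQ-attribute-anno.txt", "data.mdb", "lock.mdb"],
    "celeba64": ["data.mdb", "lock.mdb"],
    "bedroom": ["data.mdb", "lock.mdb"],
}


def find_missing_files(data_root, datasets=None):
    """Return 'dataset/file' entries that the layout expects but are absent.

    Hypothetical helper for illustration; not part of the PDAE repository.
    Pass `datasets` to check only the datasets you actually downloaded.
    """
    root = Path(data_root)
    names = datasets if datasets is not None else EXPECTED_LAYOUT.keys()
    missing = []
    for name in names:
        for filename in EXPECTED_LAYOUT[name]:
            if not (root / name / filename).is_file():
                missing.append(f"{name}/{filename}")
    return missing


if __name__ == "__main__":
    # Assumption: the `data` folder is in the current working directory.
    for path in find_missing_files("data"):
        print("missing:", path)
```

For example, `find_missing_files("data", ["ffhq"])` returns an empty list once both FFHQ LMDB files are in place.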

Download

pre-trained-dpms (required)

trained-models (optional)

Put the downloaded folders in the root directory of this project and keep the directory structure shown in Google Drive.

Training

To train a DDPM, run this command:

cd ./trainer
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_regular.py --world_size 4

To train PDAE, run this command:

cd ./trainer
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_representation_learning.py --world_size 4

To train a classifier for manipulation, run this command:

cd ./trainer
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_manipulation_diffusion.py --world_size 4

To train a latent DPM, run this command:

cd ./trainer
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train_latent_diffusion.py --world_size 4

To change the config file or the run path, edit the corresponding Python script.
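Every training command above pairs `--world_size 4` with four devices in `CUDA_VISIBLE_DEVICES`. This suggests one process per visible GPU, which is the usual PyTorch distributed setup. The sketch below is hypothetical and not taken from the repository. It counts the devices listed in the environment variable so that a matching `--world_size` value can be chosen; the one-process-per-GPU mapping is an assumption.

```python
import os


def world_size_from_env(env=None):
    """Count the GPUs listed in CUDA_VISIBLE_DEVICES.

    Hypothetical helper: assumes one process per visible GPU, matching how the
    commands above pair `--world_size 4` with devices 0,1,2,3.
    Returns None when the variable is unset (all GPUs visible).
    """
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    # Ignore empty entries, e.g. from a trailing comma ("0,1,").
    devices = [d for d in value.split(",") if d.strip()]
    return len(devices)


if __name__ == "__main__":
    n = world_size_from_env()
    if n is None:
        print("CUDA_VISIBLE_DEVICES is not set")
    else:
        print(f"--world_size {n}")
```

For example, with `CUDA_VISIBLE_DEVICES=0,1` the helper returns 2, so the matching flag would presumably be `--world_size 2`.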

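Evaluation

Each command below assumes you start from the root directory of this project.

To run the autoencoding example, run this command:

cd ./sampler
CUDA_VISIBLE_DEVICES=0 python3 autoencoding_example.py

To evaluate autoencoding, run this command:

cd ./sampler
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 autoencoding_eval.py --world_size 4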

When using the inferred $x_{T}$, PDAE achieves state-of-the-art autoencoding reconstruction performance, with an SSIM of 0.994 and an MSE of 3.84e-5.
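"Inferred $x_{T}$" means encoding an image into the starting noise of the diffusion model rather than sampling that noise at random. In Diff-AE this is done with deterministic DDIM ($\eta = 0$) updates run towards the noise end. The sketch below assumes PDAE's evaluation works the same way; it is not copied from `autoencoding_eval.py`. It uses plain Python lists instead of tensors and a linear DDPM beta schedule. It also replaces the trained network with a hypothetical constant `eps_model`, which makes the invert-then-decode round trip exact (MSE of about 0). It does not reproduce the reported numbers.

```python
import math


def linear_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t of a linear DDPM beta schedule."""
    alpha_bars, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def ddim_step(x, eps, a_from, a_to):
    """Deterministic DDIM update (eta = 0) from noise level a_from to a_to.

    a_to < a_from adds noise (inversion); a_to > a_from removes it (sampling).
    """
    out = []
    for xi, ei in zip(x, eps):
        x0_pred = (xi - math.sqrt(1.0 - a_from) * ei) / math.sqrt(a_from)
        out.append(math.sqrt(a_to) * x0_pred + math.sqrt(1.0 - a_to) * ei)
    return out


def eps_model(x, alpha_bar):
    """Hypothetical stand-in for a trained noise predictor.

    It returns a constant, so inversion and decoding see identical predictions
    and the round trip is exact. A real network is queried at different points
    in the two directions, so its reconstruction is only approximate.
    """
    return [0.1 for _ in x]


def noise_levels(alpha_bars, timesteps):
    # alpha_bar = 1.0 stands for the clean image x_0.
    return [1.0] + [alpha_bars[t] for t in timesteps]


def invert(x0, alpha_bars, timesteps):
    """Infer x_T by running DDIM updates from the image towards noise."""
    levels = noise_levels(alpha_bars, timesteps)
    x = list(x0)
    for i in range(len(timesteps)):
        x = ddim_step(x, eps_model(x, levels[i]), levels[i], levels[i + 1])
    return x


def decode(x_T, alpha_bars, timesteps):
    """Map x_T back to an image by running the updates in reverse order."""
    levels = noise_levels(alpha_bars, timesteps)
    x = list(x_T)
    for i in range(len(timesteps), 0, -1):
        x = ddim_step(x, eps_model(x, levels[i]), levels[i], levels[i - 1])
    return x


def mse(a, b):
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


if __name__ == "__main__":
    alpha_bars = linear_alpha_bars()
    timesteps = list(range(0, 1000, 10))  # 100 DDIM steps
    image = [0.2, -0.5, 0.9, 0.0]  # toy "image", pixels scaled to [-1, 1]
    x_T = invert(image, alpha_bars, timesteps)
    print("reconstruction MSE:", mse(image, decode(x_T, alpha_bars, timesteps)))
```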

To denoise in one step, run this command:

cd ./sampler
CUDA_VISIBLE_DEVICES=0 python3 denoise_one_step.py
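One-step denoising presumably relies on the standard DDPM identity that turns a noisy sample and a noise estimate into a prediction of the clean image. This README does not describe how PDAE's encoder output modifies the noise estimate, so the formula below shows only the generic identity. Here $\bar{\alpha}_t$ is the cumulative noise schedule and $\epsilon_\theta$ is the noise predictor.

```latex
\begin{aligned}
x_t &= \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I), \\
\hat{x}_0 &= \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}.
\end{aligned}
```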

To measure the gap, run this command:

cd ./sampler
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 gap_measure.py --world_size 4

To interpolate between images, run this command:

cd ./sampler
CUDA_VISIBLE_DEVICES=0 python3 interpolation.py
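Diff-AE, whose datasets this project uses, interpolates the inferred noise $x_{T}$ with spherical linear interpolation (slerp) and the semantic latent with ordinary linear interpolation. Whether `interpolation.py` does exactly the same is an assumption. The sketch below implements both on plain Python lists and is not taken from the repository.

```python
import math


def slerp(a, b, t):
    """Spherical linear interpolation between nonzero vectors a and b, t in [0, 1].

    Falls back to linear interpolation when a and b are (almost) parallel;
    anti-parallel inputs (angle near pi) are not handled.
    """
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    cos_theta = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    cos_theta = max(-1.0, min(1.0, cos_theta))  # guard against rounding
    theta = math.acos(cos_theta)
    if theta < 1e-6:
        return lerp(a, b, t)
    w_a = math.sin((1.0 - t) * theta) / math.sin(theta)
    w_b = math.sin(t * theta) / math.sin(theta)
    return [w_a * x + w_b * y for x, y in zip(a, b)]


def lerp(a, b, t):
    """Linear interpolation, e.g. for the semantic latent."""
    return [(1.0 - t) * x + t * y for x, y in zip(a, b)]
```

For high-dimensional Gaussian noise, slerp keeps intermediate vectors at roughly the norm the diffusion model expects. Linear midpoints of two nearly orthogonal noise vectors are shorter by about a factor of $\sqrt{2}$. A hypothetical usage with two inferred noises `xT_1` and `xT_2` would be `frames = [slerp(xT_1, xT_2, i / 10) for i in range(11)]`.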

To manipulate images, run this command:

cd ./sampler
CUDA_VISIBLE_DEVICES=0 python3 manipulation.py

To sample unconditionally, run this command:

cd ./sampler
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 unconditional_sample.py --world_size 4
