iSach / functional-diffusion-processes

Fork of Continuous-Time Functional Diffusion Processes (NeurIPS 2023).

Home Page: https://openreview.net/pdf?id=VPrir0p5b6


Continuous-Time Functional Diffusion Processes

Python 3.10 · OpenReview · NeurIPS 2023

Authors: Giulio Franzese, Giulio Corallo, Simone Rossi, Markus Heinonen, Maurizio Filippone, Pietro Michiardi

Accepted as a poster at the 37th Conference on Neural Information Processing Systems (NeurIPS 2023).


Abstract

We introduce Functional Diffusion Processes (FDPs), which generalize score-based diffusion models to infinite-dimensional function spaces. FDPs require a new mathematical framework to describe the forward and backward dynamics, and several extensions to derive practical training objectives. These include infinite-dimensional versions of Girsanov theorem, in order to be able to compute an ELBO, and of the sampling theorem, in order to guarantee that functional evaluations in a countable set of points are equivalent to infinite-dimensional functions. We use FDPs to build a new breed of generative models in function spaces, which do not require specialized network architectures, and that can work with any kind of continuous data. Our results on real data show that FDPs achieve high-quality image generation, using a simple MLP architecture with orders of magnitude fewer parameters than existing diffusion models.

Figures: Super Resolution on MNIST · Samples from CELEBA


Installation

pip install git+ssh://git@github.com/giulio98/functional-diffusion-processes.git
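If you do not have SSH access to GitHub configured, the same package should also be installable over HTTPS (an alternative to the command above, not part of the original instructions):

pip install git+https://github.com/giulio98/functional-diffusion-processes.git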

Quickstart

Setup the Development Environment

git clone git@github.com:giulio98/functional-diffusion-processes.git
cd functional-diffusion-processes
conda env create -f env.yaml
conda activate fdp
pip install -e .[dev]
pre-commit install

Working with Hydra

Hydra is a framework that simplifies the configuration of complex applications, including the management of hierarchical configurations. It is particularly useful for running experiments with different hyperparameters, which is a key part of the experimentation in this project.

With Hydra, you can easily sweep over parameters sequentially, as demonstrated in the experiment sections below. Additionally, Hydra supports parallel execution of experiments through its Joblib launcher plugin, which runs multiple experiment configurations concurrently and significantly speeds up experimentation with many hyperparameters.
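As a sketch of what a parallel sweep could look like, the MNIST learning-rate sweep from the training section below can be dispatched through the Joblib launcher by adding a single override (this assumes the hydra-joblib-launcher plugin is installed in the environment):

# Hypothetical parallel sweep via the Joblib launcher plugin
python3 src/functional_diffusion_processes/run.py --multirun +experiments_maml=exp_mnist \
trainers.training_config.learning_rate=1e-5,2e-5 \
hydra/launcher=joblib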


Setup the Project

Before you begin any experiments, make sure to create a .env file with the following content:

export WANDB_API_KEY=<your wandb api key>
export HOME=<your_home_directory>  # e.g., /home/username
export CUDA_HOME=/usr/local/cuda
export PROJECT_ROOT=<your_project_directory>  # e.g., /home/username/functional_diffusion_processes
export DATA_ROOT=${PROJECT_ROOT}/data
export LOGS_ROOT=${PROJECT_ROOT}/logs
export TFDS_DATA_DIR=${DATA_ROOT}/tensorflow_datasets
export PYTHONPATH=${PROJECT_ROOT}
export PYTHONUNBUFFERED=1
export HYDRA_FULL_ERROR=1
export WANDB_DISABLE_SERVICE=true
export CUDA_VISIBLE_DEVICES=<your cuda devices>
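Since the file consists of plain export statements, a minimal way to make these variables available before launching an experiment is to source it from the project root (assuming it is not already loaded automatically by your tooling):

# load the environment variables defined in .env into the current shell
source .env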

All experiments use wandb for logging. However, you can opt out of wandb by setting trainer_logging.use_wandb=False in the yaml files under conf/trainers/trainer_maml and conf/trainers/trainer_vit.
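Because the configuration is managed by Hydra, the same switch can presumably also be passed as a command-line override for a single run instead of editing the yaml files; the exact override path below is an assumption based on the overrides shown in the experiments section:

# Hypothetical one-off override to disable wandb logging
python3 src/functional_diffusion_processes/run.py +experiments_maml=exp_mnist \
trainers.trainer_logging.use_wandb=False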

Pretrained Checkpoints

To run the sampling and conditional generation experiments, you need to download the pretrained checkpoints.

All checkpoints are provided in a Google Drive folder.

Alternatively, you can download them directly by running:

pip install gdown
gdown --id 1R9aRsV7q4yU0ey47tR7hFvKttEilUv0i
unzip logs.zip
rm logs.zip

Experiments

MNIST Experiments

Find the configurations for our paper's experiments under conf/experiments_maml, with corresponding scripts in scripts/maml.

Training

Run the default training script, or use Hydra to experiment with hyperparameters:

# Default training
sh scripts/maml/train_mnist.sh

# Hyperparameter experimentation
python3 src/functional_diffusion_processes/run.py --multirun +experiments_maml=exp_mnist \
trainers.training_config.learning_rate=1e-5,2e-5

Generation

sh scripts/maml/sample_mnist.sh

Super Resolution

Run the script as-is for 128x128 resolution, or specify a different target shape:

# Default resolution
sh scripts/maml/super_resolution_mnist.sh

# Custom resolution
python3 src/functional_diffusion_processes/run.py --multirun +experiments_maml=exp_mnist_super_resolution \
samplers.sampler_config.target_shape=[512,512]

FID Evaluation

sh scripts/maml/eval_mnist.sh

CELEBA Experiments

Download CELEBA Dataset

pip install gdown
cd ~/functional-diffusion-processes/data/tensorflow_datasets/
gdown --folder https://drive.google.com/drive/folders/1eHdU3N4Tiv6BAezAAI7LAvJTItIF8GD2?usp=share_link

Scripts for training and evaluating models on the CELEBA dataset are provided, using the official configurations.

Training

Train the INR or the UViT on the CELEBA dataset:

sh scripts/maml/train_celeba.sh  # for INR
sh scripts/vit/train_celeba.sh   # for UViT

Generation and Conditional Generation

sh scripts/maml/sample_celeba.sh  # INR
sh scripts/vit/sample_celeba.sh   # UViT
sh scripts/maml/colorize_celeba.sh  # Colorization
sh scripts/maml/deblur_celeba.sh    # Deblurring
sh scripts/maml/inpaint_celeba.sh   # Inpainting

FID Evaluation

sh scripts/maml/eval_celeba.sh  # INR
sh scripts/vit/eval_celeba.sh   # UViT

Acknowledgements

Our code builds upon several outstanding open-source projects and papers.

Citation

If you use our code or paper, please cite:

@misc{franzese2023continuoustime,
      title={Continuous-Time Functional Diffusion Processes},
      author={Giulio Franzese and Giulio Corallo and Simone Rossi and Markus Heinonen and Maurizio Filippone and Pietro Michiardi},
      year={2023},
      eprint={2303.00800},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

About


License: Apache License 2.0


Languages

Python 98.9%, Shell 1.1%