DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps
The official code for the paper DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li and Jun Zhu.
DPM-Solver is a fast dedicated high-order solver for diffusion ODEs with the convergence order guarantee. DPM-Solver is suitable for both discrete-time and continuous-time diffusion models without any further training. Experimental results show that DPM-Solver can generate high-quality samples in only 10 to 20 function evaluations on various datasets.
Use DPM-Solver in your own code
It is very easy to combine DPM-Solver with your own diffusion models. We support both Pytorch and JAX code. You can just copy the file dpm_solver_pytorch.py
or dpm_solver_jax.py
(The JAX code is cleaning and will be released soon) to your own code files and import it.
In each step, DPM-Solver needs to compute the corresponding NoiseScheduleVP
class.
Moreover, DPM-Solver is designed for the continuous-time diffusion ODEs. For discrete-time diffusion models, we also implement a wrapper function to convert the discrete-time diffusion models to the continuous-time diffusion models in the model_wrapper
function.
An example for using DPM-Solver:
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
## You need to firstly define your model and the extra inputs of your model,
## And initialize an `x_T` from the standard normal distribution.
## If your model has no extra inputs, just let model_kwargs = {}.
## `model` has the format: model(x_t, t_input, **model_kwargs).
# model = ....
# model_kwargs = {...}
# x_T = ...
## 1. Define the noise schedule.
## We support the 'linear' or 'cosine' VP schedule.
noise_schedule = NoiseScheduleVP(schedule='linear')
## 2. Convert your discrete-time noise prediction model `model`
## to the continuous-time noise prediction model.
model_fn = model_wrapper(
model,
noise_schedule,
is_cond_classifier=False,
total_N=1000,
model_kwargs=model_kwargs
)
## 3. Define dpm-solver and sample by dpm-solver-fast (recommended).
## You can adjust the `steps` to balance the computation
## costs and the sample quality.
dpm_solver = DPM_Solver(model_fn, noise_schedule)
x_sample = dpm_solver.sample(
x_T,
steps=15,
eps=1e-4,
adaptive_step_size=False,
fast_version=True,
)
To be specific, you need to finish the following three simple steps.
1. Define the noise schedule.
We support the 'linear' or 'cosine' VP noise schedule. For example, for the commly-used linear schedule (i.e. the
from dpm_solver_pytorch import NoiseScheduleVP
noise_schedule = NoiseScheduleVP(schedule='linear')
If you want to custom your own designed noise schedule, you need to implement the marginal_log_mean_coeff
, marginal_std
, marginal_lambda
and inverse_lambda
functions of your noise schedule. Please refer to the detailed comments in the code of NoiseScheduleVP
.
2. Wrap your noise prediction model to the continuous-time model.
For a given noise prediction model (i.e. the $\epsilon_\theta(x_t, t)$) model
and model_kwargs
with the following format:
model(x_t, t_input, **model_kwargs)
where t_input
is the time label of the model.
(may be discrete-time labels (i.e. 0 to 999) or continuous-time labels (i.e. 0 to
We wrap the model function to the following format:
model_fn(x, t_continuous)
where t_continuous
is the continuous time labels (i.e. 0 to model_fn
for DPM-Solver.
Note that DPM-Solver only needs the noise prediction model (the "mean" model), so for diffusion models which predict both "mean" and "variance" (such as improved-DDPM), you need to firstly define another function by your model to only output the "mean".
If you want to custom your own designed model time input t_input
,
2.1. Continuous-time DPMs
For continuous-time DPMs, we have t_input = t_continuous
. You can let time_input_type
be "0"
to wrap the model function:
from dpm_solver_pytorch import model_wrapper
model_fn = model_wrapper(
model,
noise_schedule,
is_cond_classifier=False,
time_input_type="0",
model_kwargs=model_kwargs
)
2.2. Discrete-time DPMs
For discrete-time DPMs, we support two types for converting the discrete time to the continuous time (see Appendix in our paper). We recommend time_input_type = "1"
(the default setting). You also need to specify the total length of the discrete time (default is 1000
):
from dpm_solver_pytorch import model_wrapper
model_fn = model_wrapper(
model,
noise_schedule,
is_cond_classifier=False,
time_input_type="1",
total_N=1000,
model_kwargs=model_kwargs
)
2.3. DPMs with classifier guidance
For DPMs with classifier guidance, we also combine the model output with the classifier gradient. You need to specify the classifier function and the guidance scale. The classifier function has the following format:
classifier_fn(x_t, t_input)
where t_input
is the same time label as in the original diffusion model model
. For example, for discrete-time DPMs with classifier guidance:
from dpm_solver_pytorch import model_wrapper
model_fn = model_wrapper(
model,
noise_schedule,
is_cond_classifier=True,
classifier_fn=classifier_fn,
classifier_scale=1.,
time_input_type="1",
total_N=1000,
model_kwargs=model_kwargs
)
3. Define DPM-Solver and compute samples
Just let
from dpm_solver_pytorch import DPM_Solver
dpm_solver = DPM_Solver(model_fn, noise_schedule)
where model_fn
is the output of the function model_wrapper
.
You can use dpm_solver.sample
to quickly sample from DPMs. This function computes the sample at time eps
by DPM-Solver, given the initial x_T
at time T
.
We support the following algorithms:
-
Adaptive step size DPM-Solver (i.e. DPM-Solver-12 and DPM-Solver-23)
-
Fixed order DPM-Solver (i.e. DPM-Solver-1, DPM-Solver-2 and DPM-Solver-3).
-
Fast version of DPM-Solver (i.e. DPM-Solver-fast), which uses uniform logSNR steps and combine different orders of DPM-Solver.
We recommend DPM-Solver-fast for both fast sampling in few steps (<=20) and fast convergence in many steps (50 to 100).
3.1. Sampling by DPM-Solver-fast (recommended)
Let adaptive_step_size=False
and fast_version=True
.
We recommend eps=1e-3
for steps <= 15
, and eps=1e-4
for steps > 15
.
Given a fixed NFE=steps
, the sampling procedure by DPM-Solver-fast is:
-
Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
-
If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
-
If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
-
If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
For example, to sample with NFE = 15:
x_sample = dpm_solver.sample(
x_T,
steps=15,
eps=1e-4,
adaptive_step_size=False,
fast_version=True,
)
3.2. Sampling by adaptive step size DPM-Solver
Let adaptive_step_size=True
.
We recommend eps=1e-4
for better sample quality.
If order
=2, we use DPM-Solver-12 which combines DPM-Solver-1 and DPM-Solver-2.
If order
=3, we use DPM-Solver-23 which combines DPM-Solver-2 and DPM-Solver-3.
You can adjust the absolute tolerance atol
and the relative tolerance rtol
to balance the computatation costs (NFE) and the sample quality. For image data, we recommend atol=0.0078
(the default setting).
For example, to sample by DPM-Solver-12:
x_sample = dpm_solver.sample(
x_T,
eps=1e-4,
order=2,
adaptive_step_size=True,
fast_version=False,
rtol=0.05,
)
3.3. Sampling by DPM-Solver-k for k = 1, 2, 3
Let adaptive_step_size=False
and fast_version=False
.
We use DPM-Solver-order
for order
= 1 or 2 or 3, with total [steps
// order
] * order
NFE.
We support three types of skip_type
:
-
'logSNR': uniform logSNR for the time steps, recommended for DPM-Solver.
-
'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
-
'time_quadratic': quadratic time for the time steps. (Used in DDIM.)
For example, to sample by DPM-Solver-3:
x_sample = dpm_solver.sample(
x_T,
steps=30,
eps=1e-4,
order=3,
skip_type='logSNR',
adaptive_step_size=False,
fast_version=False,
)
Examples
We also add a pytorch example and a JAX example. The documentations are coming soon.
TODO List
- Documentation for example code.
- Clean and add the JAX code example.
- Add more explanations about DPM-Solver.
- Add VE type noise schedule.
References
If you find the code useful for your research, please consider citing
@article{lu2022dpm,
title={DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps},
author={Lu, Cheng and Zhou, Yuhao and Bao, Fan and Chen, Jianfei and Li, Chongxuan and Zhu, Jun},
journal={arXiv preprint arXiv:2206.00927},
year={2022}
}