The official code for the papers DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps (NeurIPS 2022 Oral) and DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models, by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li and Jun Zhu.
DPM-Solver (and its improved version, DPM-Solver++) is a fast, dedicated high-order solver for diffusion ODEs with a convergence-order guarantee. DPM-Solver works with both discrete-time and continuous-time diffusion models without any further training. Experimental results show that DPM-Solver can generate high-quality samples in only 10 to 20 function evaluations on various datasets.
Guided-Diffusion with DPM-Solver:
Stable-Diffusion with DPM-Solver++:
DiffEdit with DPM-Solver++:
🤗 Diffusers is a fantastic library for diffusion models. It supports both DPM-Solver and DPM-Solver++; the multistep DPM-Solver++ is currently the fastest solver.
The second-order multistep DPM-Solver++ is the default solver for Stable-Diffusion online demos (e.g., see example) and can also be used in LoRA (e.g., see example). Here is an example:
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
model_id = "stabilityai/stable-diffusion-2-1"
# Use the DPMSolverMultistepScheduler (DPM-Solver++) scheduler here
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
For the DeepFloyd IF pipeline, we recommend the SDE version of DPM-Solver++ for stage 1, and the ODE version of DPM-Solver++ for the upscaling stages (stages 2 and 3).
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import pt_to_pil
import torch
# stage 1
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_1.enable_model_cpu_offload()
# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_2.enable_model_cpu_offload()
# stage 3
safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker}
stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16)
stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()
def set_scheduler(stage, scheduler_name):
    # Swap in DPM-Solver++ (ODE or SDE variant), keeping the rest of the stage's scheduler config.
    if scheduler_name == 'dpm++':
        algorithm_type = 'dpmsolver++'
    elif scheduler_name == 'sde-dpm++':
        algorithm_type = 'sde-dpmsolver++'
    else:
        raise ValueError(f"unknown scheduler_name: {scheduler_name}")
    # Override algorithm_type via from_config instead of mutating the (frozen) config afterwards.
    stage.scheduler = DPMSolverMultistepScheduler.from_config(stage.scheduler.config, algorithm_type=algorithm_type)
    return stage
upscale_steps = 25
stage_1 = set_scheduler(stage_1, 'sde-dpm++')
stage_2 = set_scheduler(stage_2, 'dpm++')
stage_3 = set_scheduler(stage_3, 'dpm++')
prompt = "casual photo of a leaf maple syrup glass container sitting on a wooden table in a log cabin, high depth of field during golden hour as the sunlight shines through the windows, dusty air"
# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
generator = torch.manual_seed(0)
# stage 1
image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images
pt_to_pil(image)[0].save("./if_stage_I.png")
# stage 2
image = stage_2(
image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt", num_inference_steps=upscale_steps
).images
pt_to_pil(image)[0].save("./if_stage_II.png")
# stage 3
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100, num_inference_steps=upscale_steps).images
image[0].save("./if_stage_III.png")
- News
- Supported Models and Algorithms
- Code Examples
- Use DPM-Solver in your own code
- Documentation
- TODO List
- References
DPM-Solver has been used in:
- DreamStudio and StableBoost (thanks for the implementations by Katherine Crowson's k-diffusion repo).
- Stable-Diffusion-WebUI, which supports both DPM-Solver and DPM-Solver++. DPM-Solver++2M is currently the fastest solver. Many thanks also to Katherine Crowson's k-diffusion repo.
- Diffusers, a widely-used library for diffusion models.
- Stable-Diffusion v2 Official Code and Stable-Diffusion v1 Official Code. Check this pull request.
- Stable-Diffusion v2.0 Online Demo and Stable-Diffusion v1.5 Online Demo on HuggingFace, which uses DPM-Solver in Diffusers.
- Core ML Stable Diffusion, by Apple, and Swift Core ML Diffusers by Hugging Face.
- 2022-12-15. Swift port of DPM-Solver++ (restricted to order 2, no dynamic thresholding), used in Apple's Core ML Stable Diffusion library and 🤗 Hugging Face Swift Core ML Diffusers demo app: Source code, App Store link. See this PR for details.
- 2022-11-11. The official demo of stable-diffusion on HuggingFace Spaces 🤗 uses DPM-Solver and runs twice as fast (from 50 steps to 25 steps)! It can generate 8 images within only 4 seconds using JAX on TPUv2-8. See this tweet.
- 2022-11-08. We provide an online demo for DPM-Solver with stable-diffusion. Many thanks to HuggingFace 🤗 for their help and hardware resource support!
- 2022-11-07. Happy to announce that the multistep DPM-Solver is now supported by diffusers! Thanks for all the efforts of the huggingface team (and me ^_^). See this PR for details.
- 2022-10-26. We have released DPM-Solver v2.0, a more stable version for high-resolution image synthesis tasks. It includes the following upgrades:
  - We support discrete-time DPMs by implementing a piecewise linear interpolation of $\log\alpha_t$ for the NoiseScheduleVP class. We strongly recommend using the new implementation for discrete-time DPMs, especially for high-resolution image synthesis. Set schedule='discrete' to use the corresponding noise schedule. We also changed the mapping between discrete-time inputs and continuous-time inputs in model_wrapper, so that the converged results are consistent with those of other solvers.
  - We changed the API for model_wrapper:
    - We support four types of diffusion models: noise prediction model, data prediction model, velocity prediction model, and score function.
    - We support unconditional sampling, classifier guidance sampling, and classifier-free guidance sampling.
  - We support new algorithms for DPM-Solver, which greatly improve the quality of high-resolution image samples from guided sampling.
    - We support both DPM-Solver and DPM-Solver++. For DPM-Solver++, we further support the dynamic thresholding introduced by Imagen.
    - We support both singlestep solvers (i.e. Runge-Kutta-like solvers) and multistep solvers (i.e. Adams-Bashforth-like solvers) for DPM-Solver, of orders 1, 2 and 3.
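We support the following four types of diffusion models. You can set the model type with the argument model_type of the function model_wrapper.
Model Type | Training Objective | Example Paper |
---|---|---|
"noise": noise prediction model | $\mathbb{E}_{x_0,\epsilon,t}\left[\omega_1(t)\lVert\epsilon_\theta(x_t,t)-\epsilon\rVert_2^2\right]$ | DDPM, Stable-Diffusion |
"x_start": data prediction model | $\mathbb{E}_{x_0,\epsilon,t}\left[\omega_2(t)\lVert x_\theta(x_t,t)-x_0\rVert_2^2\right]$ | DALL·E 2 |
"v": velocity prediction model | $\mathbb{E}_{x_0,\epsilon,t}\left[\omega_3(t)\lVert v_\theta(x_t,t)-(\alpha_t\epsilon-\sigma_t x_0)\rVert_2^2\right]$ | Imagen Video |
"score": marginal score function | $\mathbb{E}_{x_0,\epsilon,t}\left[\omega_4(t)\lVert\sigma_t s_\theta(x_t,t)+\epsilon\rVert_2^2\right]$ | ScoreSDE |
Here $x_t=\alpha_t x_0+\sigma_t\epsilon$ and $\omega_i(t)$ are weighting functions.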
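Because DPM-Solver always consumes a noise prediction model, the other three output types have to be converted to $\epsilon$. Below is a minimal NumPy sketch of these conversions, assuming the VP forward process $x_t=\alpha_t x_0+\sigma_t\epsilon$ with $\alpha_t^2+\sigma_t^2=1$. The function name `to_noise_prediction` is hypothetical and not part of dpm_solver_pytorch.py.

```python
import numpy as np


def to_noise_prediction(model_output, x_t, alpha_t, sigma_t, model_type):
    """Convert a model output of type `model_type` into a noise prediction eps.

    Assumes x_t = alpha_t * x_0 + sigma_t * eps with alpha_t**2 + sigma_t**2 == 1.
    """
    if model_type == "noise":
        return model_output
    if model_type == "x_start":
        # Invert x_t = alpha_t * x_0 + sigma_t * eps for eps.
        return (x_t - alpha_t * model_output) / sigma_t
    if model_type == "v":
        # v = alpha_t * eps - sigma_t * x_0, hence eps = alpha_t * v + sigma_t * x_t.
        return alpha_t * model_output + sigma_t * x_t
    if model_type == "score":
        # The marginal score is -eps / sigma_t, hence eps = -sigma_t * score.
        return -sigma_t * model_output
    raise ValueError(f"unknown model_type: {model_type}")
```

In practice model_wrapper takes care of this step; you only set model_type.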
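We support the following three types of sampling with diffusion models. You can set the sampling type with the argument guidance_type of the function model_wrapper.
Sampling Type | Equation for Noise Prediction Model | Example Paper |
---|---|---|
"uncond": unconditional sampling | $\tilde{\epsilon}_\theta(x_t,t)=\epsilon_\theta(x_t,t)$ | DDPM |
"classifier": classifier guidance | $\tilde{\epsilon}_\theta(x_t,t,c)=\epsilon_\theta(x_t,t)-s\cdot\sigma_t\nabla_{x_t}\log q_\phi(c\mid x_t,t)$ | ADM, GLIDE |
"classifier-free": classifier-free guidance | $\tilde{\epsilon}_\theta(x_t,t,c)=s\cdot\epsilon_\theta(x_t,t,c)+(1-s)\cdot\epsilon_\theta(x_t,t)$ | DALL·E 2, Imagen, Stable-Diffusion |
Here $s$ is the guidance scale (guidance_scale).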
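To make the three sampling types concrete, here is a NumPy sketch of how the noise predictions are combined, using the standard classifier-guidance and classifier-free-guidance formulas. The function and argument names are hypothetical; in this README the combination happens inside model_wrapper.

```python
import numpy as np


def guided_noise(eps_uncond, guidance_type="uncond", eps_cond=None,
                 guidance_scale=1.0, sigma_t=None, cond_grad=None):
    """Combine noise predictions for the three sampling types.

    eps_uncond: eps_theta(x_t, t), the unconditional noise prediction.
    eps_cond:   eps_theta(x_t, t, c), used only for classifier-free guidance.
    cond_grad:  grad_x log q(c | x_t, t) from a classifier, used only for classifier guidance.
    """
    if guidance_type == "uncond":
        return eps_uncond
    if guidance_type == "classifier":
        # Shift the noise prediction against the classifier gradient, scaled by sigma_t.
        return eps_uncond - guidance_scale * sigma_t * cond_grad
    if guidance_type == "classifier-free":
        # Extrapolate from the unconditional towards the conditional prediction.
        return guidance_scale * eps_cond + (1.0 - guidance_scale) * eps_uncond
    raise ValueError(f"unknown guidance_type: {guidance_type}")
```

With guidance_scale=1, classifier-free guidance reduces to the plain conditional prediction; larger scales push samples further towards the condition.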
We support the following four algorithms: the singlestep and multistep variants of DPM-Solver and DPM-Solver++.
We also support the dynamic thresholding introduced by Imagen for the algorithms with data prediction (DPM-Solver++). Dynamic thresholding can further improve the sample quality of pixel-space DPMs with large guidance scales.
Note that the model_fn used to initialize DPM-Solver is always the noise prediction model. The algorithm_type setting selects the algorithm (DPM-Solver or DPM-Solver++), not the model. In other words, both DPM-Solver and DPM-Solver++ work with all four model types.
- In fact, we implement DPM-Solver++ by first converting the noise prediction model to a data prediction model and then sampling with DPM-Solver++; users do not need to handle this conversion.
Singlestep solvers (i.e. Runge-Kutta-like solvers) and multistep solvers (i.e. Adams-Bashforth-like solvers) perform differently, so we recommend different solvers for different tasks.
Method | Supported Orders | Supports Thresholding | Remark |
---|---|---|---|
DPM-Solver, singlestep | 1, 2, 3 | No | |
DPM-Solver, multistep | 1, 2, 3 | No | |
DPM-Solver++, singlestep | 1, 2, 3 | Yes | |
DPM-Solver++, multistep | 1, 2, 3 | Yes | Recommended for guided sampling with order = 2, and for unconditional sampling with order = 3. |
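For intuition about the recommended 2nd-order multistep DPM-Solver++, here is a NumPy sketch of a single update as written in the DPM-Solver++ paper: it linearly combines the two most recent data predictions (with weights set by the ratio of step sizes in $\lambda=\log(\alpha/\sigma)$) and applies the exponential-integrator step. The function name and argument layout are hypothetical; the very first step, which has no earlier prediction, would use a first-order update instead.

```python
import numpy as np


def dpm_solver_pp_2m_step(x_prev, x0_prev, x0_prev2, lam_prev2, lam_prev, lam_t,
                          alpha_t, sigma_prev, sigma_t):
    """One 2nd-order multistep DPM-Solver++ update from t_{i-1} to t_i.

    x0_prev, x0_prev2: data predictions at t_{i-1} and t_{i-2}.
    lam_*: half log-SNR lambda = log(alpha / sigma) at t_{i-2}, t_{i-1} and t_i.
    """
    h = lam_t - lam_prev            # current step size in lambda
    h_prev = lam_prev - lam_prev2   # previous step size in lambda
    r = h_prev / h
    # Adams-Bashforth-like combination of the two previous data predictions.
    d = (1.0 + 1.0 / (2.0 * r)) * x0_prev - (1.0 / (2.0 * r)) * x0_prev2
    # Exponential-integrator step: expm1(-h) = exp(-h) - 1.
    return (sigma_t / sigma_prev) * x_prev - alpha_t * np.expm1(-h) * d
```

When both data predictions coincide, the update reduces to the first-order DPM-Solver++ step, which is exact if the data distribution is a single point.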
We provide an example of guided-diffusion with DPM-Solver in examples/ddpm_and_guided-diffusion
.
We provide an example of stable diffusion with DPM-Solver in examples/stable-diffusion
. DPM-Solver can greatly accelerate the sampling speed of the original stable-diffusion.
We provide an example of DiffEdit with DPM-Solver, which can be used for image editing. The idea of DiffEdit can be summarized as follows: use DDIM to obtain an invertible latent series, then apply a different prompt for inpainting (controlled by an automatically generated mask).
We can easily accelerate such editing / inpainting with DPM-Solver in only 20 steps.
We provide a PyTorch example and a JAX example in examples/
which apply DPM-Solver to Yang Song's score_sde repo on CIFAR-10.
It is very easy to combine DPM-Solver with your own diffusion models. We support both PyTorch and JAX code. You can just copy the file dpm_solver_pytorch.py
or dpm_solver_jax.py
to your own code files and import it.
In each step, DPM-Solver needs to compute the corresponding $\alpha_t$, $\sigma_t$ and $\lambda_t=\log\alpha_t-\log\sigma_t$ of the noise schedule:
- For discrete-time DPMs, we support a piecewise linear interpolation of $\log\alpha_t$ in the NoiseScheduleVP class. It supports all types of VP noise schedules.
- For continuous-time DPMs, we support the linear schedule (as used in DDPM and ScoreSDE) in the NoiseScheduleVP class.
Moreover, DPM-Solver is designed for continuous-time diffusion ODEs. For discrete-time diffusion models, we also implement a wrapper in the model_wrapper function that converts them into continuous-time diffusion models.
If you want to find the best setting for accelerating sampling with your own diffusion models, here is a reference guide:
- IMPORTANT: First run 1000-step DDIM to check the sample quality of your model. If the sample quality is poor, DPM-Solver cannot improve it; please check your model definition or training process.
  Reason: DDIM is the first-order special case of DPM-Solver (proved in our paper). So, given the same noise sample at time $T$, the converged samples of DDIM and DPM-Solver are the same. DPM-Solver can accelerate convergence, but cannot improve the converged sample quality.
- If 1000-step DDIM generates good samples, DPM-Solver can achieve good sample quality within very few steps because it greatly accelerates convergence. You may then want to tune the detailed hyperparameters of DPM-Solver. Here is a comprehensive search routine:
  - Compare algorithm_type="dpmsolver" and algorithm_type="dpmsolver++". Note that these settings are for the algorithm, not for the model. In other words, even with algorithm_type="dpmsolver++", you can still use a noise prediction model (such as stable-diffusion) and the algorithm works well.
  - (Optional) Compare with / without dynamic thresholding.
    IMPORTANT: our dynamic thresholding method is only valid for pixel-space diffusion models with algorithm_type="dpmsolver++". For example, Imagen uses dynamic thresholding and greatly improves its sample quality. Thresholding pushes the pixel-space samples into a bounded range, so it generates reasonable images. However, for latent-space diffusion models (such as stable-diffusion), thresholding is unsuitable because $x_0$ at time $0$ of the diffusion model is in fact the "latent variable" in the latent space, which is unbounded.
  - Compare singlestep and multistep methods.
  - Compare order = 2, 3. Note that all the first-order versions are equivalent to DDIM, so you do not need to try them.
  - Compare steps = 10, 15, 20, 25, 50, 100. The choice depends on your computational resources and the required sample quality.
  - (Optional) Compare time_uniform, logSNR and time_quadratic for the skip type. We empirically find that time_uniform is the best setting for high-resolution images, so we recommend it and there is no need for extra searching. However, for low-resolution images such as CIFAR-10, we empirically find that logSNR is the best setting.
  - (Optional) Compare denoise_to_zero=True and denoise_to_zero=False. Empirically, denoise_to_zero=True can improve the FID for low-resolution images such as CIFAR-10, but its influence on high-resolution images seems small. As denoise_to_zero needs one additional function evaluation (i.e. one additional step), we do not recommend it for high-resolution images.
The detailed pseudo code is as follows:
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
    # Optional: dynamic thresholding is only valid with algorithm_type="dpmsolver++".
    for correcting_x0_fn in [None, "dynamic_thresholding"]:
        if algorithm_type == "dpmsolver" and correcting_x0_fn is not None:
            continue
        dpm_solver = DPM_Solver(..., algorithm_type=algorithm_type, correcting_x0_fn=correcting_x0_fn)  # ... means other arguments
        for method in ['singlestep', 'multistep']:
            for order in [2, 3]:
                for steps in [10, 15, 20, 25, 50, 100]:
                    sample = dpm_solver.sample(
                        ...,  # ... means other arguments
                        method=method,
                        order=order,
                        steps=steps,
                        # optional: skip_type='time_uniform' or 'logSNR' or 'time_quadratic',
                        # optional: denoise_to_zero=True or False
                    )
Then compare the samples to choose the best setting.
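The search routine skips order 1 because every first-order variant is equivalent to DDIM. The NumPy sketch below checks this numerically for one step from time $s$ to time $t$, assuming the VP notation $x_t=\alpha_t x_0+\sigma_t\epsilon$ and $\lambda=\log(\alpha/\sigma)$; the function names are hypothetical.

```python
import numpy as np


def half_log_snr(alpha, sigma):
    return np.log(alpha / sigma)


def ddim_step(x_s, eps, a_s, s_s, a_t, s_t):
    """Deterministic DDIM: predict x_0, then re-noise it with the same eps."""
    x0 = (x_s - s_s * eps) / a_s
    return a_t * x0 + s_t * eps


def dpm_solver_1_step(x_s, eps, a_s, s_s, a_t, s_t):
    """First-order DPM-Solver, driven by the noise prediction eps."""
    h = half_log_snr(a_t, s_t) - half_log_snr(a_s, s_s)
    return (a_t / a_s) * x_s - s_t * np.expm1(h) * eps


def dpm_solver_pp_1_step(x_s, eps, a_s, s_s, a_t, s_t):
    """First-order DPM-Solver++, driven by the data prediction derived from eps."""
    h = half_log_snr(a_t, s_t) - half_log_snr(a_s, s_s)
    x0 = (x_s - s_s * eps) / a_s
    return (s_t / s_s) * x_s - a_t * np.expm1(-h) * x0
```

All three functions return the same point for any inputs (up to floating-point error), which is why searching over order 1 adds nothing beyond DDIM.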
Moreover, for unconditional sampling and guided sampling, we provide some recommended settings and code examples, which are listed in the following section.
We recommend using the following two types of solvers for different tasks:
- 3rd-order multistep DPM-Solver:
## Define the model and noise schedule (see examples below)
## ....
## Define DPM-Solver and compute the sample.
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
## Or also try:
## dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
## Steps in [10, 20] can generate quite good samples.
## And steps = 20 can almost converge.
x_sample = dpm_solver.sample(
    x_T,
    steps=20,
    order=3,
    skip_type="time_uniform",
    method="multistep",
)
- 2nd-order multistep DPM-Solver:
  - For general DPMs (e.g. latent-space DPMs):
## Define the model and noise schedule (see examples below)
## ....
## Define DPM-Solver and compute the sample.
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
## Steps in [10, 20] can generate quite good samples.
## And steps = 20 can almost converge.
x_sample = dpm_solver.sample(
    x_T,
    steps=20,
    order=2,
    skip_type="time_uniform",
    method="multistep",
)
  - For DPMs trained on bounded data (e.g. pixel-space images), we further support the dynamic thresholding method introduced by Imagen by setting correcting_x0_fn="dynamic_thresholding". Dynamic thresholding can greatly improve the sample quality of pixel-space DPMs in guided sampling with large guidance scales.
## Define the model and noise schedule (see examples below)
## ....
## Define DPM-Solver and compute the sample.
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++", correcting_x0_fn="dynamic_thresholding")
## Steps in [10, 20] can generate quite good samples.
## And steps = 20 can almost converge.
x_sample = dpm_solver.sample(
    x_T,
    steps=20,
    order=2,
    skip_type="time_uniform",
    method="multistep",
)
Specifically, we have the following suggestions:
- For unconditional sampling:
  - For obtaining a reasonable sample as fast as possible, use the 2nd-order (dpmsolver++, multistep) DPM-Solver with steps <= 10.
  - For obtaining a good sample, use the 3rd-order (dpmsolver or dpmsolver++, multistep) DPM-Solver with steps = 15.
  - (Recommended) For obtaining an almost converged sample, use the 3rd-order (dpmsolver or dpmsolver++, multistep) DPM-Solver with steps = 20.
  - For obtaining a fully converged sample, use the 3rd-order (dpmsolver or dpmsolver++, multistep) DPM-Solver with steps = 50.
- For guided sampling (especially with large guidance scales):
  - Use the 2nd-order (dpmsolver++, multistep) DPM-Solver for all steps.
  - For pixel-space DPMs (i.e. DPMs trained on images), set correcting_x0_fn="dynamic_thresholding"; otherwise (e.g. latent-space DPMs) set correcting_x0_fn=None.
  - Choices for steps:
    - For obtaining a reasonable sample as fast as possible, use steps <= 10.
    - For obtaining a good sample, use steps = 15.
    - (Recommended) For obtaining an almost converged sample, use steps = 20.
    - For obtaining a fully converged sample, use steps = 50.
For unconditional sampling, we recommend using the 3rd-order (dpmsolver or dpmsolver++, multistep) DPM-Solver. Here is an example for discrete-time DPMs:
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
## You need to first define your model and the extra inputs of your model,
## And initialize an `x_T` from the standard normal distribution.
## `model` has the format: model(x_t, t_input, **model_kwargs).
## If your model has no extra inputs, just let model_kwargs = {}.
## If you use discrete-time DPMs, you need to further define the
## beta arrays for the noise schedule.
# model = ....
# model_kwargs = {...}
# x_T = ...
# betas = ....
## 1. Define the noise schedule.
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
## 2. Convert your discrete-time `model` to the continuous-time
## noise prediction model. Here is an example for a diffusion model
## `model` with the noise prediction type ("noise") .
model_fn = model_wrapper(
model,
noise_schedule,
model_type="noise", # or "x_start" or "v" or "score"
model_kwargs=model_kwargs,
)
## 3. Define dpm-solver and sample by multistep DPM-Solver.
## (We recommend the 3rd-order multistep DPM-Solver for unconditional sampling)
## You can adjust the `steps` to balance the computation
## costs and the sample quality.
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
## Can also try
# dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
## You can use steps = 10, 12, 15, 20, 25, 50, 100.
## Empirically, we find that steps in [10, 20] can generate quite good samples.
## And steps = 20 can almost converge.
x_sample = dpm_solver.sample(
x_T,
steps=20,
order=3,
skip_type="time_uniform",
method="multistep",
)
For guided sampling with classifier guidance, we recommend using the 2nd-order (dpmsolver++, multistep) DPM-Solver, especially for large guidance scales. Here is an example for discrete-time DPMs:
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
## You need to first define your model and the extra inputs of your model,
## And initialize an `x_T` from the standard normal distribution.
## `model` has the format: model(x_t, t_input, **model_kwargs).
## If your model has no extra inputs, just let model_kwargs = {}.
## If you use discrete-time DPMs, you need to further define the
## beta arrays for the noise schedule.
## For classifier guidance, you need to further define a classifier function,
## a guidance scale and a condition variable.
# model = ....
# model_kwargs = {...}
# x_T = ...
# condition = ...
# betas = ....
# classifier = ...
# classifier_kwargs = {...}
# guidance_scale = ...
## 1. Define the noise schedule.
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
## 2. Convert your discrete-time `model` to the continuous-time
## noise prediction model. Here is an example for a diffusion model
## `model` with the noise prediction type ("noise") .
model_fn = model_wrapper(
model,
noise_schedule,
model_type="noise", # or "x_start" or "v" or "score"
model_kwargs=model_kwargs,
guidance_type="classifier",
condition=condition,
guidance_scale=guidance_scale,
classifier_fn=classifier,
classifier_kwargs=classifier_kwargs,
)
## 3. Define dpm-solver and sample by multistep DPM-Solver.
## (We recommend multistep DPM-Solver for conditional sampling)
## You can adjust the `steps` to balance the computation
## costs and the sample quality.
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
## If the DPM is defined on pixel-space images, you can further
## set `correcting_x0_fn="dynamic_thresholding"`. e.g.:
# dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++",
# correcting_x0_fn="dynamic_thresholding")
## You can use steps = 10, 12, 15, 20, 25, 50, 100.
## Empirically, we find that steps in [10, 20] can generate quite good samples.
## And steps = 20 can almost converge.
x_sample = dpm_solver.sample(
x_T,
steps=20,
order=2,
skip_type="time_uniform",
method="multistep",
)
For guided sampling with classifier-free guidance, we recommend using the 2nd-order (dpmsolver++, multistep) DPM-Solver, especially for large guidance scales. Here is an example for discrete-time DPMs:
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
## You need to first define your model and the extra inputs of your model,
## And initialize an `x_T` from the standard normal distribution.
## `model` has the format: model(x_t, t_input, cond, **model_kwargs).
## If your model has no extra inputs, just let model_kwargs = {}.
## If you use discrete-time DPMs, you need to further define the
## beta arrays for the noise schedule.
## For classifier-free guidance, you need to further define a guidance scale,
## a condition variable and an unconditional condition variable.
# model = ....
# model_kwargs = {...}
# x_T = ...
# condition = ...
# unconditional_condition = ...
# betas = ....
# guidance_scale = ...
## 1. Define the noise schedule.
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
## 2. Convert your discrete-time `model` to the continuous-time
## noise prediction model. Here is an example for a diffusion model
## `model` with the noise prediction type ("noise") .
model_fn = model_wrapper(
model,
noise_schedule,
model_type="noise", # or "x_start" or "v" or "score"
model_kwargs=model_kwargs,
guidance_type="classifier-free",
condition=condition,
unconditional_condition=unconditional_condition,
guidance_scale=guidance_scale,
)
## 3. Define dpm-solver and sample by multistep DPM-Solver.
## (We recommend multistep DPM-Solver for conditional sampling)
## You can adjust the `steps` to balance the computation
## costs and the sample quality.
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
## If the DPM is defined on pixel-space images, you can further
## set `correcting_x0_fn="dynamic_thresholding"`. e.g.:
# dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++",
# correcting_x0_fn="dynamic_thresholding")
## You can use steps = 10, 12, 15, 20, 25, 50, 100.
## Empirically, we find that steps in [10, 20] can generate quite good samples.
## And steps = 20 can almost converge.
x_sample = dpm_solver.sample(
x_T,
steps=20,
order=2,
skip_type="time_uniform",
method="multistep",
)
We support the commonly-used variance preserving (VP) noise schedule for both discrete-time and continuous-time DPMs:
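For discrete-time DPMs, we support a piecewise linear interpolation of $\log\alpha_t$ in the NoiseScheduleVP class. It supports all types of VP noise schedules.
We need either the $\beta_i$ array (betas) or the $\bar{\alpha}_i=\prod_{j\le i}(1-\beta_j)$ array (alphas_cumprod) to define the noise schedule.
Define the discrete-time noise schedule by the $\beta_i$ array:
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
Or define the discrete-time noise schedule by the $\bar{\alpha}_i$ array: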
noise_schedule = NoiseScheduleVP(schedule='discrete', alphas_cumprod=alphas_cumprod)
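As a rough illustration of the piecewise linear interpolation of $\log\alpha_t$, here is a NumPy sketch that accepts either betas or alphas_cumprod and derives $\alpha_t$, $\sigma_t$ and $\lambda_t=\log\alpha_t-\log\sigma_t$ at any continuous $t$. The class name DiscreteVPSchedule is hypothetical, and placing discrete label $n$ at time $(n+1)/N$ is an assumption of this sketch rather than a documented detail of NoiseScheduleVP.

```python
import numpy as np


class DiscreteVPSchedule:
    """Hypothetical sketch of a discrete-time VP schedule with interpolated log(alpha_t)."""

    def __init__(self, betas=None, alphas_cumprod=None):
        if betas is not None:
            # log(alpha_n) = 0.5 * sum_{j<=n} log(1 - beta_j)
            log_alphas = 0.5 * np.cumsum(np.log(1.0 - np.asarray(betas, dtype=float)))
        else:
            log_alphas = 0.5 * np.log(np.asarray(alphas_cumprod, dtype=float))
        self.N = len(log_alphas)
        # Assumed placement: discrete label n sits at continuous time (n + 1) / N.
        self.t_array = np.linspace(0.0, 1.0, self.N + 1)[1:]
        self.log_alpha_array = log_alphas

    def marginal_log_mean_coeff(self, t):
        # Piecewise linear interpolation of log(alpha_t) between the knots.
        return np.interp(t, self.t_array, self.log_alpha_array)

    def marginal_alpha(self, t):
        return np.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        # sigma_t = sqrt(1 - alpha_t^2) for a VP schedule.
        return np.sqrt(-np.expm1(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        # Half log-SNR: lambda_t = log(alpha_t) - log(sigma_t).
        return self.marginal_log_mean_coeff(t) - np.log(self.marginal_std(t))
```

Interpolating $\log\alpha_t$ (rather than $\alpha_t$) keeps the schedule well behaved at intermediate times that fall between two discrete labels.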
We support both linear schedule (as used in DDPM and ScoreSDE) and cosine schedule (as used in improved-DDPM) for the continuous-time DPMs.
Define the continuous-time linear noise schedule:
noise_schedule = NoiseScheduleVP(schedule='linear', continuous_beta_0=0.1, continuous_beta_1=20.)
Define the continuous-time cosine noise schedule:
noise_schedule = NoiseScheduleVP(schedule='cosine')
For a given diffusion model whose time input may be discrete-time labels (i.e. 0 to 999) or continuous-time times (i.e. 0 to 1), and whose output type may be "noise", "x_start", "v" or "score", we wrap the model function into the following format:
model_fn(x, t_continuous) -> noise
where t_continuous is the continuous time (i.e. 0 to 1) and the output type is "noise", i.e. a noise prediction model. We use this continuous-time noise prediction model model_fn for DPM-Solver.
Note that DPM-Solver only needs the noise prediction model (the $\epsilon_\theta(x_t, t)$ model).
After defining the noise schedule, we need to further wrap the model
into a continuous-time noise prediction model. The given model
has the following format:
model(x_t, t_input, **model_kwargs) -> noise | x_start | v | score
And we wrap the model by:
model_fn = model_wrapper(
model,
noise_schedule,
model_type=model_type, # "noise" or "x_start" or "v" or "score"
model_kwargs=model_kwargs,
)
where model_kwargs are the additional inputs of the model, and model_type can be "noise", "x_start", "v" or "score".
After defining the noise schedule, we need to further wrap the model
into a continuous-time noise prediction model. The given model
has the following format:
model(x_t, t_input, **model_kwargs) -> noise | x_start | v | score
For DPMs with classifier guidance, we also combine the model output with the classifier gradient. We need to specify the classifier function and the guidance scale. The classifier function has the following format:
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
where t_input is the same time label as in the original diffusion model model, cond is the condition variable, and classifier_kwargs are the other inputs of the classifier function.
And we wrap the model by:
model_fn = model_wrapper(
model,
noise_schedule,
model_type=model_type, # "noise" or "x_start" or "v" or "score"
model_kwargs=model_kwargs,
guidance_type="classifier",
condition=condition,
guidance_scale=guidance_scale,
classifier_fn=classifier,
classifier_kwargs=classifier_kwargs,
)
where model_kwargs are the additional inputs of the model, model_type can be "noise", "x_start", "v" or "score", guidance_scale is the classifier guidance scale, and condition is the conditional input of the classifier.
After defining the noise schedule, we need to further wrap the model
into a continuous-time noise prediction model. The given model
has the following format:
model(x_t, t_input, cond, **model_kwargs) -> noise | x_start | v | score
Note that for classifier-free guidance, the model needs another input cond. If cond is the special variable unconditional_condition, the model output is the unconditional DPM output.
And we wrap the model by:
model_fn = model_wrapper(
model,
noise_schedule,
model_type=model_type, # "noise" or "x_start" or "v" or "score"
model_kwargs=model_kwargs,
guidance_type="classifier-free",
condition=condition,
unconditional_condition=unconditional_condition,
guidance_scale=guidance_scale,
)
where model_kwargs are the additional inputs of the model, model_type can be "noise", "x_start", "v" or "score", guidance_scale is the classifier-free guidance scale, condition is the conditional input, and unconditional_condition is the special unconditional condition variable for the unconditional model.
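One way to evaluate classifier-free guidance with a single network call is to stack the conditional and unconditional inputs into one doubled batch and split the output. The NumPy sketch below does this with a hypothetical toy_model standing in for a real network; it assumes your model accepts a batch of conditions through cond, which may not hold for every model.

```python
import numpy as np


def toy_model(x, t_input, cond):
    # Hypothetical stand-in for a conditional noise prediction network (ignores time).
    return 0.1 * x + cond[:, None]


def cfg_noise(model, x, t_input, condition, unconditional_condition, guidance_scale):
    """Classifier-free guided noise prediction from one batched model call."""
    x_in = np.concatenate([x, x], axis=0)
    t_in = np.concatenate([t_input, t_input], axis=0)
    c_in = np.concatenate([unconditional_condition, condition], axis=0)
    # First half of the batch is unconditional, second half is conditional.
    eps_uncond, eps_cond = np.split(model(x_in, t_in, c_in), 2, axis=0)
    # Equivalent to s * eps_cond + (1 - s) * eps_uncond.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```

Batching trades memory for one fewer forward pass per step; if memory is tight, calling the model twice gives the same result.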
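Below we introduce the detailed mapping between the discrete-time labels and the continuous-time times. It is not necessary to understand these details to use DPM-Solver.
For discrete-time DPMs, the noise prediction model is trained on the discrete-time labels from $0$ to $N-1$. We convert them to continuous-time times $t \in [1/N, 1]$ by $t_n = (n+1)/N$, i.e. we map the discrete-time label $0$ to time $1/N$ and the discrete-time label $N-1$ to time $1$.
For continuous-time DPMs defined on $t \in [\epsilon, T]$, we directly use the continuous-time times as the model input.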
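A sketch of the discrete-to-continuous time mapping, assuming discrete labels $0,\dots,N-1$ are placed at continuous times $t_n=(n+1)/N$ (the function names are hypothetical). Intermediate time steps chosen by the solver generally map back to non-integer labels, so the model must accept fractional time inputs.

```python
def label_to_time(n, N=1000):
    """Discrete label n in {0, ..., N-1} -> continuous time (n + 1) / N in [1/N, 1]."""
    return (n + 1) / N


def time_to_label(t_continuous, N=1000):
    """Inverse map: continuous time in [1/N, 1] -> (possibly fractional) label in [0, N-1]."""
    return (t_continuous - 1.0 / N) * N
```

For example, with $N=1000$ the continuous time $0.5$ corresponds to label $499$, and $t=1$ corresponds to the noisiest label $999$.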
After defining the model_fn
by the function model_wrapper
, we can further use model_fn
to define DPM-Solver and compute samples.
- For DPM-Solver with the "dpmsolver" algorithm, define
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
- For DPM-Solver with the "dpmsolver++" algorithm, define
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
- For DPM-Solver with the "dpmsolver++" algorithm and dynamic thresholding, define
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++", correcting_x0_fn="dynamic_thresholding", thresholding_max_val=1.0)
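The dynamic thresholding enabled by correcting_x0_fn="dynamic_thresholding" can be sketched as follows, following the description in the Imagen paper: for each sample, take a high percentile $s$ of $|x_0|$, raise it to at least thresholding_max_val, clip $x_0$ to $[-s, s]$ and divide by $s$. This NumPy function is a hypothetical stand-in, and the default percentile ratio=0.995 is an assumption rather than a documented value.

```python
import numpy as np


def dynamic_thresholding(x0, ratio=0.995, max_val=1.0):
    """Imagen-style dynamic thresholding of a batch of data predictions x0 (batch first)."""
    flat = np.abs(x0.reshape(x0.shape[0], -1))
    # Per-sample percentile of |x0|.
    s = np.quantile(flat, ratio, axis=1)
    # Never shrink the range below max_val, then broadcast over the non-batch dims.
    s = np.maximum(s, max_val).reshape((-1,) + (1,) * (x0.ndim - 1))
    # Clip to [-s, s] and rescale into [-1, 1].
    return np.clip(x0, -s, s) / s
```

Samples already inside $[-1, 1]$ are left unchanged, while saturated samples are squeezed back into range instead of being hard-clipped, which is why the method helps most at large guidance scales.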
You can use dpm_solver.sample
to quickly sample from DPMs. This function computes the ODE solution at time t_end
by DPM-Solver, given the initial x
at time t_start
.
We support the following algorithms:
- Singlestep DPM-Solver. We combine all the singlestep solvers with order <= order to use up all the function evaluations (steps).
- Multistep DPM-Solver.
- Fixed-order singlestep DPM-Solver (i.e. DPM-Solver-1, DPM-Solver-2 and DPM-Solver-3).
- Adaptive step size DPM-Solver (i.e. DPM-Solver-12 and DPM-Solver-23).
We support three types of skip_type for the choice of intermediate time steps:
- logSNR: uniform logSNR for the time steps. Recommended for low-resolution images.
- time_uniform: uniform time for the time steps. Recommended for high-resolution images.
- time_quadratic: quadratic time for the time steps.
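The three skip types can be sketched in NumPy as follows. For logSNR the sketch needs $\lambda_t$ and its inverse, so it assumes the continuous-time linear VP schedule with continuous_beta_0=0.1 and continuous_beta_1=20., where $\log\alpha_t=-\tfrac{\beta_1-\beta_0}{4}t^2-\tfrac{\beta_0}{2}t$ can be inverted in closed form. All function names here are hypothetical.

```python
import numpy as np

# Assumed continuous-time linear VP schedule (beta_0 = 0.1, beta_1 = 20).
BETA_0, BETA_1 = 0.1, 20.0


def marginal_log_mean_coeff(t):
    # log(alpha_t) = -(beta_1 - beta_0) t^2 / 4 - beta_0 t / 2
    return -0.25 * t**2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0


def marginal_lambda(t):
    log_alpha = marginal_log_mean_coeff(t)
    # lambda_t = log(alpha_t) - log(sigma_t), with sigma_t^2 = 1 - alpha_t^2
    return log_alpha - 0.5 * np.log(-np.expm1(2.0 * log_alpha))


def inverse_lambda(lam):
    # Solve log(alpha_t) = -0.5 * log(1 + exp(-2 lambda)) for t >= 0 (a quadratic in t).
    tmp = 2.0 * (BETA_1 - BETA_0) * np.logaddexp(0.0, -2.0 * lam)
    return tmp / (np.sqrt(BETA_0**2 + tmp) + BETA_0) / (BETA_1 - BETA_0)


def get_time_steps(skip_type, t_T, t_0, N):
    """Return N + 1 decreasing time points from t_T down to t_0."""
    if skip_type == "logSNR":
        lams = np.linspace(marginal_lambda(t_T), marginal_lambda(t_0), N + 1)
        return inverse_lambda(lams)
    if skip_type == "time_uniform":
        return np.linspace(t_T, t_0, N + 1)
    if skip_type == "time_quadratic":
        # Uniform in sqrt(t), then squared: steps get denser near t_0.
        return np.linspace(np.sqrt(t_T), np.sqrt(t_0), N + 1) ** 2
    raise ValueError(f"unknown skip_type: {skip_type}")
```

Uniform logSNR spends more steps near $t_0$, where $\lambda_t$ changes quickly, whereas time_uniform spreads them evenly in $t$.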
We combine all the singlestep solvers with order <= order
to use up all the function evaluations (steps). The total number of function evaluations (NFE) == steps
.
For discrete-time DPMs, we do not need to specify t_start and t_end. The default setting is to sample from the discrete-time label $N-1$ to the discrete-time label $0$. For example:
## discrete-time DPMs
x_sample = dpm_solver.sample(
x_T,
steps=20,
order=3,
skip_type="time_uniform",
method="singlestep",
)
For continuous-time DPMs, sampling runs from `t_start=1.0` (the default setting) to `t_end`. We recommend `t_end=1e-3` for `steps <= 15` and `t_end=1e-4` for `steps > 15`. For example:
x_sample = dpm_solver.sample(
x_T,
t_end=1e-3,
steps=12,
order=3,
skip_type="time_uniform",
method="singlestep",
)
## continuous-time DPMs
x_sample = dpm_solver.sample(
x_T,
t_end=1e-4,
steps=20,
order=3,
skip_type="time_uniform",
method="singlestep",
)
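The `t_end` recommendation for continuous-time DPMs fits in a tiny helper. `recommended_t_end` and `singlestep_kwargs` are hypothetical names; the returned dictionary simply mirrors the keyword arguments used in the examples above so it could be unpacked into `dpm_solver.sample`.

```python
def recommended_t_end(steps: int) -> float:
    """t_end suggested for continuous-time DPMs: 1e-3 for <= 15 steps, else 1e-4."""
    return 1e-3 if steps <= 15 else 1e-4


def singlestep_kwargs(steps: int, order: int = 3, skip_type: str = "time_uniform") -> dict:
    """Keyword arguments matching the singlestep examples above (hypothetical helper)."""
    return {
        "t_end": recommended_t_end(steps),
        "steps": steps,
        "order": order,
        "skip_type": skip_type,
        "method": "singlestep",
    }


# Usage (assuming dpm_solver and x_T are defined as above):
# x_sample = dpm_solver.sample(x_T, **singlestep_kwargs(12))
```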
Implementation of singlestep DPM-Solver: given a fixed NFE == `steps`, the sampling procedure is:
- If `order` == 1:
    - Denote K = `steps`. We use K steps of DPM-Solver-1 (i.e., DDIM).
- If `order` == 2:
    - Denote K = (`steps` // 2) + (`steps` % 2). We take K intermediate time steps for sampling.
    - If `steps` % 2 == 0, we use K steps of singlestep DPM-Solver-2.
    - If `steps` % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If `order` == 3:
    - Denote K = (`steps` // 3 + 1). We take K intermediate time steps for sampling.
    - If `steps` % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
    - If `steps` % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
    - If `steps` % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
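The singlestep allocation rules can be written as a small function that returns the order of each of the K singlestep updates. `singlestep_orders` is a hypothetical name; the sketch only reproduces the order bookkeeping, not the solver updates. Since a singlestep DPM-Solver-k update costs k function evaluations, the orders always sum to `steps`.

```python
from typing import List


def singlestep_orders(steps: int, order: int) -> List[int]:
    """Order of each singlestep update for a budget of `steps` function evaluations."""
    if order == 1:
        return [1] * steps  # K = steps updates of DPM-Solver-1 (DDIM)
    if order == 2:
        K = steps // 2 + steps % 2
        if steps % 2 == 0:
            return [2] * K
        return [2] * (K - 1) + [1]
    if order == 3:
        K = steps // 3 + 1
        if steps % 3 == 0:
            return [3] * (K - 2) + [2, 1]
        if steps % 3 == 1:
            return [3] * (K - 1) + [1]
        return [3] * (K - 1) + [2]
    raise ValueError("order must be 1, 2 or 3")
```

For example, `singlestep_orders(20, 3)` gives six DPM-Solver-3 updates followed by one DPM-Solver-2 update, i.e. 6 * 3 + 2 = 20 NFE.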
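For multistep DPM-Solver (`method="multistep"`) with discrete-time DPMs, you also do not need to specify `t_start` and `t_end`: by default, sampling runs from the discrete-time label N - 1 to the discrete-time label 0 (N being the number of discrete time steps of the model). For example: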
## discrete-time DPMs
x_sample = dpm_solver.sample(
x_T,
steps=20,
order=2,
skip_type="time_uniform",
method="multistep",
)
For continuous-time DPMs, sampling runs from `t_start=1.0` (the default setting) to `t_end`. We recommend `t_end=1e-3` for `steps <= 15` and `t_end=1e-4` for `steps > 15`. For example:
x_sample = dpm_solver.sample(
x_T,
t_end=1e-3,
steps=10,
order=2,
skip_type="time_uniform",
method="multistep",
)
## continuous-time DPMs
x_sample = dpm_solver.sample(
x_T,
t_end=1e-4,
steps=20,
order=3,
skip_type="time_uniform",
method="multistep",
)
Implementation of multistep DPM-Solver: a multistep update of order k uses the model outputs at the k most recent time points, so the first updates, taken before that history exists, use lower-order multistep solvers.
Given a fixed NFE == `steps`, the sampling procedure is:
- Denote K = `steps`.
- If `order` == 1:
    - We use K steps of DPM-Solver-1 (i.e., DDIM).
- If `order` == 2:
    - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
- If `order` == 3:
    - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
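The multistep warm-up can be sketched in the same way: update i (counting from 1) uses order min(i, `order`). `multistep_orders` is a hypothetical helper that mirrors only the procedure listed above. Because a multistep update reuses earlier model outputs, each update costs one function evaluation, so there are exactly K = `steps` updates.

```python
from typing import List


def multistep_orders(steps: int, order: int) -> List[int]:
    """Order of each multistep update: lower orders first, until enough history exists."""
    if order not in (1, 2, 3):
        raise ValueError("order must be 1, 2 or 3")
    # Update i (0-based) has i + 1 model outputs available, so its order is capped there.
    return [min(i + 1, order) for i in range(steps)]
```

For example, `multistep_orders(20, 3)` is one DPM-Solver-1 update, one multistep DPM-Solver-2 update, then 18 multistep DPM-Solver-3 updates.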
Adaptive step size DPM-Solver (`method="adaptive"`) ignores `steps` and instead adapts the step size using a pair of solvers whose higher order is `order`. For continuous-time DPMs, we recommend `t_end=1e-4` for better sample quality.
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to trade off computation cost (NFE) against sample quality. For image data, we recommend `atol=0.0078` (the default setting).
- If `order` == 2, we use DPM-Solver-12, which combines DPM-Solver-1 and singlestep DPM-Solver-2.
- If `order` == 3, we use DPM-Solver-23, which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
For example, to sample by DPM-Solver-12:
x_sample = dpm_solver.sample(
x_T,
t_end=1e-4,
order=2,
method="adaptive",
rtol=0.05,
)
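The role of `atol` and `rtol` can be illustrated with a generic embedded-pair step-size controller. The sketch below is only a toy analogue of DPM-Solver-12, not the library's implementation: it integrates a scalar ODE `dy/dt = f(t, y)` in plain time, uses Euler (1st order) and Heun (2nd order) as the lower/higher-order pair, and assumes a safety factor `theta = 0.9`, an initial step `h_init = 0.05`, and the error measure `|y_higher - y_lower| / max(atol, rtol * max(|y_lower|, |y_prev|))`.

```python
import math
from typing import Callable, Tuple


def adaptive_solve(
    f: Callable[[float, float], float],
    y0: float,
    t0: float,
    t1: float,
    h_init: float = 0.05,
    atol: float = 0.0078,
    rtol: float = 0.05,
    theta: float = 0.9,
    t_err: float = 1e-5,
) -> Tuple[float, int]:
    """Integrate dy/dt = f(t, y) from t0 to t1 (t1 > t0); return (y(t1), NFE)."""
    t, y, y_prev = t0, y0, y0
    h, nfe = h_init, 0
    while t1 - t > t_err:
        h = min(h, t1 - t)                  # never step past t1
        k1 = f(t, y)
        y_lower = y + h * k1                # Euler: 1st-order estimate
        k2 = f(t + h, y_lower)
        y_higher = y + 0.5 * h * (k1 + k2)  # Heun: 2nd-order estimate, reuses k1
        nfe += 2                            # rejected attempts also cost evaluations
        # Mixed tolerance: atol dominates when |y| is small, rtol when it is large.
        delta = max(atol, rtol * max(abs(y_lower), abs(y_prev)))
        err = abs(y_higher - y_lower) / delta
        if err <= 1.0:                      # accept: advance with the higher-order value
            t, y, y_prev = t + h, y_higher, y_lower
        # Rescale h from the error estimate; it shrinks whenever err > 1.
        h = theta * h * max(err, 1e-12) ** (-1.0 / 2)
    return y, nfe


# Usage: dy/dt = -y with y(0) = 1, whose exact value at t = 1 is exp(-1).
y, nfe = adaptive_solve(lambda t, y: -y, 1.0, 0.0, 1.0)
```

Tightening `atol` and `rtol` lowers the error but raises the NFE, which is the trade-off the tolerances control.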
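Fixed-order singlestep DPM-Solver (`method="singlestep_fixed"`) uses DPM-Solver-`order` for every step, with `order` = 1, 2 or 3, for a total of (`steps` // `order`) * `order` NFE; if `steps` is not divisible by `order`, the remainder of the budget is not used.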
For example, to sample by DPM-Solver-3:
x_sample = dpm_solver.sample(
x_T,
steps=30,
order=3,
skip_type="time_uniform",
method="singlestep_fixed",
)
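The fixed-order NFE formula is plain floor arithmetic. `fixed_order_nfe` is a hypothetical helper used only to show how much of a `steps` budget this method actually spends.

```python
def fixed_order_nfe(steps: int, order: int) -> int:
    """NFE of fixed-order singlestep DPM-Solver: (steps // order) updates of `order` NFE each."""
    if order not in (1, 2, 3):
        raise ValueError("order must be 1, 2 or 3")
    return (steps // order) * order
```

For the DPM-Solver-3 example above, `fixed_order_nfe(30, 3)` uses the full budget of 30, whereas `fixed_order_nfe(20, 3)` spends only 18 of 20 evaluations; `method="singlestep"` would instead mix orders to use all 20.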
TODO list:
- Add stable-diffusion examples.
- Support Diffusers.
- Documentation for example code.
- Clean and add the JAX code example.
- Add more explanations about DPM-Solver.
- Add a small jupyter example.
- Add VE type noise schedule.
- Support downstream applications (e.g. inpainting, etc.).
If you find the code useful for your research, please consider citing:
@article{lu2022dpm,
title={DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps},
author={Lu, Cheng and Zhou, Yuhao and Bao, Fan and Chen, Jianfei and Li, Chongxuan and Zhu, Jun},
journal={arXiv preprint arXiv:2206.00927},
year={2022}
}