feat: add log-rho DEIS multistep scheduler (open-mmlab#1432)
* feat: add log-rho DEIS multistep scheduler

* docs: fix typo

* docs: add docs for the implemented algorithm

* docs: remove duplicate reference

* finish DEIS

* add docs

* fix

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
qsh-zh and patrickvonplaten authored Jan 4, 2023
1 parent 9b63854 commit be99201
Showing 10 changed files with 743 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -155,6 +155,8 @@
title: "DDIM"
- local: api/schedulers/ddpm
title: "DDPM"
- local: api/schedulers/deis
title: "DEIS"
- local: api/schedulers/singlestep_dpm_solver
title: "Singlestep DPM-Solver"
- local: api/schedulers/multistep_dpm_solver
22 changes: 22 additions & 0 deletions docs/source/en/api/schedulers/deis.mdx
@@ -0,0 +1,22 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DEIS

DEIS (Diffusion Exponential Integrator Sampler) is a fast sampling method for diffusion models, introduced in *Fast Sampling of Diffusion Models with Exponential Integrator*.

## Overview

The original paper can be found [here](https://arxiv.org/abs/2204.13902), and the original implementation [here](https://github.com/qsh-zh/deis).

## DEISMultistepScheduler
[[autodoc]] DEISMultistepScheduler
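
Since the new docs page only pulls in the autodoc stub, a minimal usage sketch may help orient readers. The checkpoint name, device, prompt, and step count below are illustrative assumptions and are not part of this commit:

```python
# Minimal sketch (not part of this commit): swap a pipeline's default scheduler
# for the new DEISMultistepScheduler. Assumes the "runwayml/stable-diffusion-v1-5"
# checkpoint and a CUDA device are available.
import torch
from diffusers import StableDiffusionPipeline, DEISMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Re-use the checkpoint's scheduler config so the noise schedule stays consistent.
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# DEIS targets few-step sampling, so a small step count is typical.
image = pipe("an astronaut riding a horse on mars", num_inference_steps=20).images[0]
image.save("deis_astronaut.png")
```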
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -67,6 +67,7 @@
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
1 change: 1 addition & 0 deletions src/diffusers/schedulers/__init__.py
@@ -24,6 +24,7 @@
else:
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
481 changes: 481 additions & 0 deletions src/diffusers/schedulers/scheduling_deis_multistep.py

Large diffs are not rendered by default.
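
The scheduler implementation itself lands in this new `scheduling_deis_multistep.py` file, whose diff is not rendered inline. As a rough orientation, the configuration keys and the stepping API match what the tests further down exercise; a minimal sketch, with a hypothetical stand-in `model` in place of a real noise-prediction network, could look like this:

```python
# Sketch only: the constructor arguments mirror the defaults used by the tests
# below; `model` is a hypothetical stand-in for an epsilon-predicting UNet.
import torch
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
    solver_order=2,           # second-order log-rho DEIS multistep update
    algorithm_type="deis",
    solver_type="logrho",
)

def model(sample, t):
    # Stand-in for a real noise-prediction model (cf. dummy_model in the tests).
    return 0.1 * sample

scheduler.set_timesteps(num_inference_steps=10)
sample = torch.randn(1, 3, 64, 64)  # stand-in for an initial noisy latent

for t in scheduler.timesteps:
    residual = model(sample, t)
    sample = scheduler.step(residual, t, sample).prev_sample
```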

10 changes: 8 additions & 2 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -174,9 +174,15 @@ def __init__(

# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if algorithm_type == "deis":
algorithm_type = "dpmsolver++"
else:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
if solver_type == "logrho":
solver_type = "midpoint"
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

# setable values
self.num_inference_steps = None
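
These guards mean the DPM-Solver schedulers now tolerate DEIS-style configuration values instead of raising, quietly falling back to their closest DPM-Solver equivalents; the single-step scheduler below receives the identical treatment. A small sketch of what this enables (illustrative, not part of the diff):

```python
# After this change, constructing a DPM-Solver scheduler from DEIS-style config
# values no longer raises NotImplementedError; "deis" falls back to
# "dpmsolver++" and "logrho" falls back to "midpoint" internally.
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    algorithm_type="deis",
    solver_type="logrho",
)
scheduler.set_timesteps(20)  # behaves like the dpmsolver++/midpoint configuration
```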
10 changes: 8 additions & 2 deletions src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -163,9 +163,15 @@ def __init__(

# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if algorithm_type == "deis":
algorithm_type = "dpmsolver++"
else:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
if solver_type == "logrho":
solver_type = "midpoint"
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

# setable values
self.num_inference_steps = None
3 changes: 3 additions & 0 deletions src/diffusers/utils/constants.py
@@ -41,4 +41,7 @@
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
"DPMSolverSinglestepScheduler",
"KDPM2DiscreteScheduler",
"KDPM2AncestralDiscreteScheduler",
"DEISMultistepScheduler",
]
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
@@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DEISMultistepScheduler(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DPMSolverMultistepScheduler(metaclass=DummyObject):
_backends = ["torch"]

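
For context, this dummy entry ensures that `from diffusers import DEISMultistepScheduler` still succeeds when PyTorch is not installed, deferring the failure until the class is actually used. A simplified sketch of the pattern (not the exact diffusers helpers) is:

```python
# Simplified sketch of the dummy-object pattern. The real `requires_backends`
# and `DummyObject` live in diffusers' import utilities; this module is only
# used when torch is missing, so every use of the class raises a clear error.
def requires_backends(obj, backends):
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class DummyObject(type):
    # Class-level attribute access (e.g. DEISMultistepScheduler.from_pretrained)
    # funnels through the metaclass and raises the backend error.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)

class DEISMultistepScheduler(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```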
202 changes: 202 additions & 0 deletions tests/test_scheduler.py
@@ -27,6 +27,7 @@
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
@@ -2505,6 +2506,207 @@ def test_full_loop_device(self):
assert abs(result_mean.item() - 0.0266) < 1e-3


class DEISMultistepSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DEISMultistepScheduler,)
forward_default_kwargs = (("num_inference_steps", 25),)

def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"solver_order": 2,
}

config.update(**kwargs)
return config

def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

def test_from_save_pretrained(self):
pass

def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)

# copy over dummy past residuals (must be after setting timesteps)
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)

# copy over dummy past residual (must be after setting timesteps)
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)

for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample

return sample

def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)

num_inference_steps = kwargs.pop("num_inference_steps", None)

for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)

sample = self.dummy_sample
residual = 0.1 * sample

if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps

# copy over dummy past residuals (must be done after set_timesteps)
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

time_step_0 = scheduler.timesteps[5]
time_step_1 = scheduler.timesteps[6]

output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)

def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)

def test_thresholding(self):
self.check_over_configs(thresholding=False)
for order in [1, 2, 3]:
for solver_type in ["logrho"]:
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
algorithm_type="deis",
solver_order=order,
solver_type=solver_type,
)

def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)

def test_solver_order_and_type(self):
for algorithm_type in ["deis"]:
for solver_type in ["logrho"]:
for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
assert not torch.isnan(sample).any(), "Samples have nan numbers"

def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)

def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

def test_full_loop_no_noise(self):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.23916) < 1e-3

def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.091) < 1e-3

def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
scheduler = scheduler_class(**scheduler_config)

num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter.half()
scheduler.set_timesteps(num_inference_steps)

for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample

assert sample.dtype == torch.float16


class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
num_inference_steps = 10
