Skip to content

Commit

Permalink
Add K-LMS scheduler from k-diffusion (#185)
Browse files Browse the repository at this point in the history
* test LMS with LDM

* test LMS with LDM

* Interchangeable sigma and timestep. Added dummy objects

* Debug

* cuda generator

* Fix derivatives

* Update tests

* Rename Lms->LMS
  • Loading branch information
anton-l authored Aug 16, 2022
1 parent 9070c39 commit d7b6920
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 7 deletions.
8 changes: 7 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available


__version__ = "0.1.3"
Expand All @@ -27,11 +27,17 @@
SchedulerMixin,
ScoreSdeVeScheduler,
)


if is_scipy_available():
from .schedulers import LMSDiscreteScheduler

from .training_utils import EMAModel


if is_transformers_available():
from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline


else:
from .utils.dummy_transformers_objects import *
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler


class StableDiffusionPipeline(DiffusionPipeline):
Expand All @@ -18,7 +18,7 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler],
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
scheduler = scheduler.set_format("pt")
Expand Down Expand Up @@ -105,9 +105,16 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
self.scheduler.set_timesteps(num_inference_steps)
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]

for i, t in tqdm(enumerate(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i]
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
Expand All @@ -118,7 +125,10 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_scipy_available
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin


if is_scipy_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
134 changes: 134 additions & 0 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import numpy as np
import torch

from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
tensor_format="pt",
):
"""
Linear Multistep Scheduler for discrete beta schedules.
Based on the original k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
"""

if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.derivatives = []

self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)

def get_lms_coefficient(self, order, t, current_order):
"""
Compute a linear multistep coefficient
"""

def lms_derivative(tau):
prod = 1.0
for k in range(order):
if current_order == k:
continue
prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
return prod

integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

return integrated_coeff

def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

low_idx = np.floor(self.timesteps).astype(int)
high_idx = np.ceil(self.timesteps).astype(int)
frac = np.mod(self.timesteps, 1.0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
self.sigmas = np.concatenate([sigmas, [0.0]])

self.derivatives = []

self.set_format(tensor_format=self.tensor_format)

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
):
sigma = self.sigmas[timestep]

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output

# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
self.derivatives.append(derivative)
if len(self.derivatives) > order:
self.derivatives.pop(0)

# 3. Compute linear multistep coefficients
order = min(timestep + 1, order)
lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]

# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
)

return {"prev_sample": prev_sample}

def add_noise(self, original_samples, noise, timesteps):
alpha_prod = self.alphas_cumprod[timesteps]
alpha_prod = self.match_shape(alpha_prod, original_samples)

noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise
return noisy_samples

def __len__(self):
return self.config.num_train_timesteps
19 changes: 19 additions & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@
_modelcards_available = False


_scipy_available = importlib.util.find_spec("scipy") is not None
try:
_scipy_version = importlib_metadata.version("scipy")
logger.debug(f"Successfully imported transformers version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
_scipy_available = False


def is_transformers_available():
return _transformers_available

Expand All @@ -85,6 +93,10 @@ def is_modelcards_available():
return _modelcards_available


def is_scipy_available():
return _scipy_available


class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
Expand Down Expand Up @@ -118,11 +130,18 @@ class RevisionNotFoundError(HTTPError):
"""


SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""


BACKENDS_MAPPING = OrderedDict(
[
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
]
)

Expand Down
24 changes: 24 additions & 0 deletions src/diffusers/utils/dummy_scipy_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


class LmsDiscreteScheduler(metaclass=DummyObject):
_backends = ["scipy"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])


class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["scipy"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])


class StableDiffusionPipeline(metaclass=DummyObject):

This comment has been minimized.

Copy link
@pcuenca

pcuenca Aug 16, 2022

Member

Why do we need to include the pipeline itself? I think I'm not fully following the logic here.

_backends = ["scipy"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])
24 changes: 22 additions & 2 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
KarrasVeScheduler,
LDMPipeline,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
Expand Down Expand Up @@ -841,7 +842,7 @@ def test_ldm_text2img_fast(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
Expand All @@ -862,7 +863,7 @@ def test_stable_diffusion(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_fast_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

Expand Down Expand Up @@ -977,3 +978,22 @@ def test_karras_ve_pipeline(self):
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1-diffusers"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
pipe.scheduler = scheduler

prompt = "a photograph of an astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
"sample"
]

image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

0 comments on commit d7b6920

Please sign in to comment.