Commit b648ea4

fix fail cases for tests (#3)
* fix bugs in repository consistency
1 parent b168357 commit b648ea4

2 files changed: 29 additions & 18 deletions


src/diffusers/schedulers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -65,6 +65,7 @@
     _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
     _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
     _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
+    _import_structure["scheduling_sasolver"] = ["SASolverScheduler"]

 try:
     if not is_flax_available():

@@ -155,6 +156,7 @@
     from .scheduling_unipc_multistep import UniPCMultistepScheduler
     from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler
+    from .scheduling_sasolver import SASolverScheduler

     try:
         if not is_flax_available():
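With both the lazy `_import_structure` entry and the `TYPE_CHECKING` import in place, the new scheduler resolves like any other. A minimal usage sketch, assuming the patched diffusers is installed and that SASolverScheduler accepts the standard num_train_timesteps config argument:

# Minimal sketch exercising the registration added above; the constructor
# argument is the standard diffusers scheduler config field, assumed here
# to be supported by SASolverScheduler.
from diffusers.schedulers import SASolverScheduler

scheduler = SASolverScheduler(num_train_timesteps=1000)
print(len(scheduler))  # 1000, served by the __len__ fixed at the end of this diff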

src/diffusers/schedulers/scheduling_sasolver.py

Lines changed: 27 additions & 18 deletions
@@ -16,21 +16,19 @@
 # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

 import math
-from typing import List, Optional, Tuple, Union, Callable
-
 import numpy as np
 import torch
-
+from typing import List, Optional, Tuple, Union, Callable
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
-        num_diffusion_timesteps,
-        max_beta=0.999,
-        alpha_transform_type="cosine",
+    num_diffusion_timesteps,
+    max_beta=0.999,
+    alpha_transform_type="cosine",
 ):
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
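The hunk above only reorders imports and normalizes parameter indentation; `betas_for_alpha_bar` itself is unchanged. For context, a standalone sketch of the cosine alpha-bar transform it discretizes (values illustrative):

import math
import torch

# Sketch of the cosine transform behind betas_for_alpha_bar:
# beta_i = 1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), capped at max_beta.
def alpha_bar(t: float) -> float:
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

T, max_beta = 1000, 0.999
betas = torch.tensor(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta) for i in range(T)]
)
print(betas[0].item(), betas[-1].item())  # tiny at t = 0, capped at 0.999 near t = T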
@@ -275,26 +273,25 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         https://arxiv.org/abs/2205.11487
         """
         dtype = sample.dtype
-        batch_size, channels, height, width = sample.shape
+        batch_size, channels, *remaining_dims = sample.shape

         if dtype not in (torch.float32, torch.float64):
             sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

         # Flatten sample for doing quantile calculation along each image
-        sample = sample.reshape(batch_size, channels * height * width)
+        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

         abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

         s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
         s = torch.clamp(
             s, min=1, max=self.config.sample_max_value
         )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
         s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
         sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

-        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.reshape(batch_size, channels, *remaining_dims)
         sample = sample.to(dtype)

         return sample

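The shape change is the substantive fix here: unpacking `*remaining_dims` lets dynamic thresholding handle inputs that are not strictly 4D. A self-contained sketch of the same logic (the function name and the 0.995 ratio are illustrative, not the scheduler's API):

import numpy as np
import torch

# Illustrative re-implementation of the generalized dynamic thresholding above;
# works for (B, C, H, W) image latents and, e.g., (B, C, T) audio-style latents alike.
def dynamic_threshold(sample: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    batch_size, channels, *remaining_dims = sample.shape
    flat = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))
    s = torch.quantile(flat.abs(), ratio, dim=1).clamp(min=1, max=max_value).unsqueeze(1)
    flat = torch.clamp(flat, -s, s) / s  # threshold to [-s, s], then rescale
    return flat.reshape(batch_size, channels, *remaining_dims)

print(dynamic_threshold(torch.randn(2, 4, 8, 8)).shape)  # torch.Size([2, 4, 8, 8])
print(dynamic_threshold(torch.randn(2, 4, 256)).shape)   # torch.Size([2, 4, 256])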
@@ -301,7 +298,7 @@
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
-        log_sigma = np.log(sigma)
+        log_sigma = np.log(np.maximum(sigma, 1e-10))

         # get distribution
         dists = log_sigma - log_sigmas[:, np.newaxis]
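The `np.maximum` guard matters when the sigma array ends in an exact zero (e.g., a schedule that terminates at sigma = 0): without it the log is -inf, which then poisons the interpolation downstream in `_sigma_to_t`. A quick illustration:

import numpy as np

# np.log(0.0) is -inf (with a RuntimeWarning); clamping first keeps it finite.
sigma = np.array([14.6, 1.0, 0.0])
print(np.log(np.maximum(sigma, 1e-10)))  # [  2.681   0.    -23.026] -- no -inf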
@@ -326,8 +323,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
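After `sigma_min`/`sigma_max` are resolved (from the config when present, otherwise from the endpoints of `in_sigmas`), the method builds the standard Karras et al. (2022) rho-schedule. A standalone sketch of that ramp, with illustrative endpoint values:

import numpy as np

# Karras et al. (2022) noise schedule: interpolate linearly in sigma^(1/rho) space,
# then raise back to the rho power; rho = 7.0 as in the paper.
def karras_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int, rho: float = 7.0) -> np.ndarray:
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(0.002, 80.0, 5))  # decreasing from 80.0 down to 0.002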
@@ -832,10 +841,10 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
-            self,
-            original_samples: torch.FloatTensor,
-            noise: torch.FloatTensor,
-            timesteps: torch.IntTensor,
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
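The reindented signature is cosmetic; `add_noise` still applies the closed-form DDPM forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A minimal sketch with a made-up linear alphas_cumprod schedule and illustrative shapes:

import torch

# Closed-form forward noising that add_noise implements; schedule is illustrative.
alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)
x0 = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x0)
timesteps = torch.tensor([10, 900])

sqrt_ab = alphas_cumprod[timesteps].sqrt().view(-1, 1, 1, 1)              # broadcast over (C, H, W)
sqrt_one_minus_ab = (1 - alphas_cumprod[timesteps]).sqrt().view(-1, 1, 1, 1)
noisy = sqrt_ab * x0 + sqrt_one_minus_ab * noise                           # x_t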
@@ -855,4 +864,4 @@ def add_noise(
         return noisy_samples

     def __len__(self):
-            return self.config.num_train_timesteps
+        return self.config.num_train_timesteps
