 # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

 import math
-from typing import List, Optional, Tuple, Union, Callable
-
 import numpy as np
 import torch
-
+from typing import List, Optional, Tuple, Union, Callable
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
-    num_diffusion_timesteps,
-    max_beta=0.999,
-    alpha_transform_type="cosine",
+    num_diffusion_timesteps,
+    max_beta=0.999,
+    alpha_transform_type="cosine",
 ):
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
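
The hunk above only shows the `betas_for_alpha_bar` signature (the parameter lines change in whitespace only) and the start of its docstring. For context, here is a minimal sketch of the cosine variant as it is typically implemented in diffusers-style schedulers; it is illustrative, not the verbatim upstream body:

```python
import math
import torch

def betas_for_alpha_bar_sketch(num_diffusion_timesteps, max_beta=0.999):
    # Cosine alpha_bar from the "Improved DDPM" schedule:
    #   alpha_bar(t) = cos^2((t + 0.008) / 1.008 * pi / 2)
    def alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), clipped to avoid a singularity near t = 1
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
```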
@@ -275,33 +273,32 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         https://arxiv.org/abs/2205.11487
         """
         dtype = sample.dtype
-        batch_size, channels, height, width = sample.shape
+        batch_size, channels, *remaining_dims = sample.shape

         if dtype not in (torch.float32, torch.float64):
             sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

         # Flatten sample for doing quantile calculation along each image
-        sample = sample.reshape(batch_size, channels * height * width)
+        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

         abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

         s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
         s = torch.clamp(
             s, min=1, max=self.config.sample_max_value
         )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
         s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
         sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

-        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.reshape(batch_size, channels, *remaining_dims)
         sample = sample.to(dtype)

         return sample

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
-        log_sigma = np.log(sigma)
+        log_sigma = np.log(np.maximum(sigma, 1e-10))

         # get distribution
         dists = log_sigma - log_sigmas[:, np.newaxis]
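
Two behavioural changes land in this hunk: `_threshold_sample` now unpacks `*remaining_dims`, so dynamic thresholding works for latents with any number of trailing dimensions rather than only 4-D `batch × channels × height × width` images, and `_sigma_to_t` floors `sigma` at `1e-10` so a zero sigma no longer produces `-inf` from the log. A small illustrative sketch of the reshape round-trip on a hypothetical 5-D video latent:

```python
import numpy as np
import torch

# Illustrative only: the same unpack/flatten pattern used in the patched _threshold_sample,
# shown on a hypothetical (batch, channels, frames, height, width) latent.
sample = torch.randn(2, 4, 8, 16, 16)
batch_size, channels, *remaining_dims = sample.shape                    # remaining_dims == [8, 16, 16]
flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))   # shape (2, 4 * 2048)
restored = flat.reshape(batch_size, channels, *remaining_dims)
assert restored.shape == sample.shape

# The log guard in _sigma_to_t: np.log(0.0) is -inf; flooring avoids that.
print(np.log(np.maximum(0.0, 1e-10)))  # ~ -23.03 instead of -inf
```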
@@ -326,8 +323,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
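
The patched `_convert_to_karras` now prefers `sigma_min`/`sigma_max` declared on the scheduler config and only falls back to the endpoints of `in_sigmas`. The remainder of the function (unchanged and therefore not shown in the hunk) builds the Karras et al. (2022) schedule with `rho = 7.0`. A standalone sketch of that construction, with assumed example values:

```python
import numpy as np

def karras_sigmas_sketch(sigma_min, sigma_max, num_inference_steps, rho=7.0):
    # sigma_i = (sigma_max^(1/rho) + ramp_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# Example values only: monotonically decreasing noise levels from sigma_max to sigma_min.
print(karras_sigmas_sketch(0.002, 80.0, 5))
```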
@@ -832,10 +841,10 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
-        self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
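
This hunk only re-indents the `add_noise` signature; the body, copied from `DDPMScheduler.add_noise`, applies the standard forward-diffusion formula x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A simplified sketch of that computation (broadcasting details may differ slightly from the upstream body):

```python
import torch

# Sketch of the forward process that add_noise implements:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
def add_noise_sketch(original_samples, noise, timesteps, alphas_cumprod):
    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    # Unsqueeze the per-timestep scalars until they broadcast over the sample dimensions.
    while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
    return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
```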
@@ -855,4 +864,4 @@ def add_noise(
         return noisy_samples

     def __len__(self):
-        return self.config.num_train_timesteps
+        return self.config.num_train_timesteps