@@ -99,7 +99,7 @@ def alpha_bar_fn(t):
9999
100100
101101# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
102- def rescale_zero_terminal_snr (betas ) :
102+ def rescale_zero_terminal_snr (betas : torch . Tensor ) -> torch . Tensor :
103103 """
104104 Rescales betas to have zero terminal SNR. Based on https://huggingface.co/papers/2305.08891 (Algorithm 1).
105105
@@ -187,14 +187,14 @@ def __init__(
187187 num_train_timesteps : int = 1000 ,
188188 beta_start : float = 0.0001 ,
189189 beta_end : float = 0.02 ,
190- beta_schedule : str = "linear" ,
190+ beta_schedule : Literal [ "linear" , "scaled_linear" , "squaredcos_cap_v2" ] = "linear" ,
191191 trained_betas : Optional [Union [np .ndarray , List [float ]]] = None ,
192192 clip_sample : bool = True ,
193193 set_alpha_to_one : bool = True ,
194194 steps_offset : int = 0 ,
195- prediction_type : str = "epsilon" ,
195+ prediction_type : Literal [ "epsilon" , "sample" , "v_prediction" ] = "epsilon" ,
196196 clip_sample_range : float = 1.0 ,
197- timestep_spacing : str = "leading" ,
197+ timestep_spacing : Literal [ "leading" , "trailing" ] = "leading" ,
198198 rescale_betas_zero_snr : bool = False ,
199199 ** kwargs ,
200200 ):
@@ -210,7 +210,15 @@ def __init__(
210210 self .betas = torch .linspace (beta_start , beta_end , num_train_timesteps , dtype = torch .float32 )
211211 elif beta_schedule == "scaled_linear" :
212212 # this schedule is very specific to the latent diffusion model.
213- self .betas = torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = torch .float32 ) ** 2
213+ self .betas = (
214+ torch .linspace (
215+ beta_start ** 0.5 ,
216+ beta_end ** 0.5 ,
217+ num_train_timesteps ,
218+ dtype = torch .float32 ,
219+ )
220+ ** 2
221+ )
214222 elif beta_schedule == "squaredcos_cap_v2" :
215223 # Glide cosine schedule
216224 self .betas = betas_for_alpha_bar (num_train_timesteps )
@@ -256,7 +264,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
256264 """
257265 return sample
258266
259- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
267+ def set_timesteps (
268+ self ,
269+ num_inference_steps : int ,
270+ device : Optional [Union [str , torch .device ]] = None ,
271+ ) -> None :
260272 """
261273 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
262274
@@ -308,20 +320,10 @@ def step(
308320 Args:
309321 model_output (`torch.Tensor`):
310322 The direct output from learned diffusion model.
311- timestep (`float `):
323+ timestep (`int `):
312324 The current discrete timestep in the diffusion chain.
313325 sample (`torch.Tensor`):
314326 A current instance of a sample created by the diffusion process.
315- eta (`float`):
316- The weight of noise for added noise in diffusion step.
317- use_clipped_model_output (`bool`, defaults to `False`):
318- If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
319- because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
320- clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
321- `use_clipped_model_output` has no effect.
322- variance_noise (`torch.Tensor`):
323- Alternative to generating noise with `generator` by directly providing the noise for the variance
324- itself. Useful for methods such as [`CycleDiffusion`].
325327 return_dict (`bool`, *optional*, defaults to `True`):
326328 Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
327329 `tuple`.
@@ -335,7 +337,8 @@ def step(
335337 # 1. get previous step value (=t+1)
336338 prev_timestep = timestep
337339 timestep = min (
338- timestep - self .config .num_train_timesteps // self .num_inference_steps , self .config .num_train_timesteps - 1
340+ timestep - self .config .num_train_timesteps // self .num_inference_steps ,
341+ self .config .num_train_timesteps - 1 ,
339342 )
340343
341344 # 2. compute alphas, betas
@@ -378,5 +381,5 @@ def step(
378381 return (prev_sample , pred_original_sample )
379382 return DDIMSchedulerOutput (prev_sample = prev_sample , pred_original_sample = pred_original_sample )
380383
381- def __len__ (self ):
384+ def __len__ (self ) -> int :
382385 return self .config .num_train_timesteps
0 commit comments