Commit d4f97d1

Improve docstrings and type hints in scheduling_ddim_inverse.py (#13020)
docs: improve docstrings in scheduling_ddim_inverse.py
1 parent 1d32b19 · commit d4f97d1

File tree

1 file changed
src/diffusers/schedulers/scheduling_ddim_inverse.py

Lines changed: 22 additions & 19 deletions
@@ -99,7 +99,7 @@ def alpha_bar_fn(t):
 
 
 # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
     """
     Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
 
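For context, a minimal sketch of what the rescaling in Algorithm 1 of the linked paper does; this is an illustration, not the library's exact implementation:

import torch

def rescale_zero_terminal_snr_sketch(betas: torch.Tensor) -> torch.Tensor:
    # Work in sqrt(alpha_bar) space.
    alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first, last = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    # Shift so the terminal value is exactly zero (zero terminal SNR),
    # then rescale so the initial value is unchanged.
    alphas_bar_sqrt = (alphas_bar_sqrt - last) * first / (first - last)
    # Convert cumulative products back to per-step betas.
    alphas_bar = alphas_bar_sqrt**2
    alphas = torch.cat([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas

betas = torch.linspace(0.0001, 0.02, 1000)
print(rescale_zero_terminal_snr_sketch(betas)[-1])  # tensor(1.) -> terminal SNR is 0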
@@ -187,14 +187,14 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
         clip_sample_range: float = 1.0,
-        timestep_spacing: str = "leading",
+        timestep_spacing: Literal["leading", "trailing"] = "leading",
         rescale_betas_zero_snr: bool = False,
         **kwargs,
     ):
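The Literal annotations make the accepted string values visible in the signature and checkable by a static type checker. A minimal usage sketch (any value outside the listed literals would be flagged):

from diffusers import DDIMInverseScheduler

scheduler = DDIMInverseScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",   # one of: "linear", "scaled_linear", "squaredcos_cap_v2"
    prediction_type="v_prediction",  # one of: "epsilon", "sample", "v_prediction"
    timestep_spacing="trailing",     # one of: "leading", "trailing"
)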
@@ -210,7 +210,15 @@ def __init__(
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            self.betas = (
+                torch.linspace(
+                    beta_start**0.5,
+                    beta_end**0.5,
+                    num_train_timesteps,
+                    dtype=torch.float32,
+                )
+                ** 2
+            )
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
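The change above is purely cosmetic; the wrapped expression computes the same tensor as the original one-liner, which a quick self-contained check confirms:

import torch

beta_start, beta_end, n = 0.0001, 0.02, 1000
one_liner = torch.linspace(beta_start**0.5, beta_end**0.5, n, dtype=torch.float32) ** 2
wrapped = (
    torch.linspace(
        beta_start**0.5,
        beta_end**0.5,
        n,
        dtype=torch.float32,
    )
    ** 2
)
assert torch.equal(one_liner, wrapped)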
@@ -256,7 +264,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
         """
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
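Annotating device as Optional matches the existing None default. A brief usage sketch:

from diffusers import DDIMInverseScheduler

scheduler = DDIMInverseScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)  # device=None keeps timesteps on CPU
print(scheduler.timesteps.shape)  # torch.Size([50])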
@@ -308,20 +320,10 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`float`):
+            timestep (`int`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            eta (`float`):
-                The weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`, defaults to `False`):
-                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
-                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
-                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
-                `use_clipped_model_output` has no effect.
-            variance_noise (`torch.Tensor`):
-                Alternative to generating noise with `generator` by directly providing the noise for the variance
-                itself. Useful for methods such as [`CycleDiffusion`].
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
                 `tuple`.
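The deleted entries (eta, use_clipped_model_output, variance_noise) described parameters that this scheduler's step does not accept, so the docstring now matches the signature. A sketch of the typical inversion loop driving step (unet here is a hypothetical epsilon-predicting model, not part of this file):

import torch
from diffusers import DDIMInverseScheduler

scheduler = DDIMInverseScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64)  # stand-in for an encoded image latent
for t in scheduler.timesteps:
    model_output = unet(latents, t)  # hypothetical model call
    latents = scheduler.step(model_output, t, latents).prev_sample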
@@ -335,7 +337,8 @@ def step(
         # 1. get previous step value (=t+1)
         prev_timestep = timestep
         timestep = min(
-            timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
+            timestep - self.config.num_train_timesteps // self.num_inference_steps,
+            self.config.num_train_timesteps - 1,
         )
 
         # 2. compute alphas, betas
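Behavior is unchanged by the wrapping; concretely, with 1000 training timesteps and 50 inference steps the stride is 1000 // 50 = 20, and the result is capped at 999:

num_train_timesteps, num_inference_steps = 1000, 50
timestep = 980  # example current timestep in the inversion chain
prev_timestep = timestep  # in inversion, the "previous" step is the larger t
timestep = min(
    timestep - num_train_timesteps // num_inference_steps,
    num_train_timesteps - 1,
)
print(prev_timestep, timestep)  # 980 960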
@@ -378,5 +381,5 @@ def step(
             return (prev_sample, pred_original_sample)
         return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
