Skip to content

Commit

Permalink
Add back-compatibility to LMS timesteps (open-mmlab#750)
Browse files Browse the repository at this point in the history
* Add back-compatibility to LMS timesteps

* style
  • Loading branch information
anton-l authored Oct 6, 2022
1 parent c119dc4 commit df9c070
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,6 @@ def step(
When returning a tuple, the first element is the sample tensor.
"""
if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
warnings.warn(
f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
"Make sure to pass one of the `scheduler.timesteps`"
)
if not self.is_scale_input_called:
warnings.warn(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
Expand All @@ -215,7 +210,18 @@ def step(

if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
warnings.warn(
"Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version"
" 0.5.0. Make sure to pass one of the `scheduler.timesteps`."
)
step_index = timestep
else:
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
Expand Down Expand Up @@ -250,7 +256,14 @@ def add_noise(
sigmas = self.sigmas.to(original_samples.device)
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
warnings.warn(
"Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in"
" version 0.5.0. Make sure to pass values from `scheduler.timesteps`."
)
step_indices = timesteps
else:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down

0 comments on commit df9c070

Please sign in to comment.