[Scheduler] fix: EDM schedulers when using the exp sigma schedule. (huggingface#8385)

* fix: EDMEulerScheduler when using the exp sigma schedule.

* fix-copies

* remove print.

* reduce friction

* yiyi's suggestions
sayakpaul authored Jun 5, 2024
1 parent 2f6f426 commit 48207d6
Showing 2 changed files with 4 additions and 7 deletions.
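
In short: with sigma_schedule="exponential", _compute_exponential_sigmas already returns a torch.Tensor, so the old code's torch.from_numpy call on its output failed; the fix builds the ramp with torch.linspace and drops the numpy round-trip so both schedules stay in torch end to end. A minimal repro sketch (my own, not from the PR; it assumes the public diffusers API at this commit):

# Repro sketch (assumption: diffusers public API as of this commit).
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler(sigma_schedule="exponential")
# Pre-fix: raised TypeError inside set_timesteps, because torch.from_numpy
# was applied to a torch.Tensor. Post-fix: completes normally.
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.timesteps[:5])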
5 changes: 2 additions & 3 deletions src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

@@ -243,13 +243,13 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):

         self.num_inference_steps = num_inference_steps
 
-        ramp = np.linspace(0, 1, self.num_inference_steps)
+        ramp = torch.linspace(0, 1, self.num_inference_steps)
         if self.config.sigma_schedule == "karras":
             sigmas = self._compute_karras_sigmas(ramp)
         elif self.config.sigma_schedule == "exponential":
             sigmas = self._compute_exponential_sigmas(ramp)
 
-        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+        sigmas = sigmas.to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)
 
         if self.config.final_sigmas_type == "sigma_min":
@@ -283,7 +283,6 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
         min_inv_rho = sigma_min ** (1 / rho)
         max_inv_rho = sigma_max ** (1 / rho)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
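
For context, the exponential schedule that both files feed into builds its sigmas directly with torch ops, which is the root of the clash with a numpy ramp. A sketch of that helper (an assumption paraphrasing the k-diffusion-style formula with EDM-style default sigmas; only the signature appears in this diff):

import math
import torch

# Sketch of the helper's presumed body (assumption, not part of the diff):
# sigmas are evenly spaced in log-space, returned highest-noise first.
def compute_exponential_sigmas(ramp, sigma_min=0.002, sigma_max=80.0):
    # Only len(ramp) is used here, so the old numpy ramp did not fail inside
    # this helper -- it failed later, at torch.from_numpy on the result.
    return torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)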
6 changes: 2 additions & 4 deletions src/diffusers/schedulers/scheduling_edm_euler.py
@@ -16,7 +16,6 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
-import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
@@ -210,13 +209,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         self.num_inference_steps = num_inference_steps
 
-        ramp = np.linspace(0, 1, self.num_inference_steps)
+        ramp = torch.linspace(0, 1, self.num_inference_steps)
         if self.config.sigma_schedule == "karras":
             sigmas = self._compute_karras_sigmas(ramp)
         elif self.config.sigma_schedule == "exponential":
             sigmas = self._compute_exponential_sigmas(ramp)
 
-        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+        sigmas = sigmas.to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)
 
         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -234,7 +233,6 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
         min_inv_rho = sigma_min ** (1 / rho)
         max_inv_rho = sigma_max ** (1 / rho)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-
         return sigmas
 
     def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
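
A quick smoke test for the fixed path across both EDM schedulers (my own sketch, not part of the PR; it assumes both classes are importable from the top-level diffusers namespace):

import torch
from diffusers import EDMDPMSolverMultistepScheduler, EDMEulerScheduler

# Both schedulers should now accept the exponential schedule without the
# torch.from_numpy TypeError and keep sigmas as a torch.Tensor throughout.
for cls in (EDMEulerScheduler, EDMDPMSolverMultistepScheduler):
    scheduler = cls(sigma_schedule="exponential")
    scheduler.set_timesteps(num_inference_steps=10, device="cpu")
    assert isinstance(scheduler.sigmas, torch.Tensor)
    print(cls.__name__, scheduler.sigmas.dtype, tuple(scheduler.sigmas.shape))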
