Skip to content

Commit

Permalink
Optimize some unneeded if conditions in the sampling code.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jun 15, 2024
1 parent 0ec513d commit f2e844e
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]

if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
Expand Down Expand Up @@ -170,7 +175,13 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]

sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
Expand Down Expand Up @@ -199,8 +210,13 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]

if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
Expand Down

0 comments on commit f2e844e

Please sign in to comment.