diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index f9b281894d4..5bb991e76a3 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -129,8 +129,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. - sigma_hat = sigmas[i] * (gamma + 1) + if s_churn > 0: + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + sigma_hat = sigmas[i] * (gamma + 1) + else: + gamma = 0 + sigma_hat = sigmas[i] + if gamma > 0: eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 @@ -170,7 +175,13 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + if s_churn > 0: + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + sigma_hat = sigmas[i] * (gamma + 1) + else: + gamma = 0 + sigma_hat = sigmas[i] + sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: eps = torch.randn_like(x) * s_noise @@ -199,8 +210,13 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. - sigma_hat = sigmas[i] * (gamma + 1) + if s_churn > 0: + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + sigma_hat = sigmas[i] * (gamma + 1) + else: + gamma = 0 + sigma_hat = sigmas[i] + if gamma > 0: eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5