Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 286 additions & 0 deletions comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,292 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x

@torch.no_grad()
def sample_euler_ancestral_multipass(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    pass_steps=2,
    pass_sigma_max=float("inf"),
    pass_sigma_min=12.0,
):
    """Euler-Ancestral sampling with optional step subdivision ("multipass").

    Any step whose starting sigma falls inside [pass_sigma_min, pass_sigma_max]
    is split into `pass_steps` equally spaced sub-steps; every other step runs
    as a single ordinary Euler-Ancestral update:

        x = x + d * dt + noise * sigma_up
    """
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x, seed=extra_args.get("seed", None))

    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        start = sigmas[i]
        end = sigmas[i + 1]

        # Subdivide only inside the configured sigma window.
        n_sub = pass_steps if pass_sigma_min <= start <= pass_sigma_max else 1
        grid = torch.linspace(start, end, n_sub + 1, device=sigmas.device)

        for j in range(n_sub):
            s_cur = grid[j]
            s_nxt = grid[j + 1]

            # Denoise at the current sub-sigma.
            denoised = model(x, s_cur * s_in, **extra_args)

            if callback is not None:
                callback({'x': x, 'i': i, 'sub_step': j, 'sigma': s_cur, 'denoised': denoised})

            # Split the sub-interval into a deterministic step down to
            # sigma_down plus stochastic renoising of size sigma_up.
            sigma_down, sigma_up = get_ancestral_step(s_cur, s_nxt, eta=eta)
            if sigma_down == 0.0:
                # Stepping all the way to zero: the model prediction is final.
                x = denoised
                continue
            # Deterministic Euler part of the ancestral update.
            x = x + to_d(x, s_cur, denoised) * (sigma_down - s_cur)
            if sigma_up != 0.0:
                # Stochastic ("ancestral") part: re-inject noise.
                x = x + noise_sampler(s_cur, s_nxt) * (s_noise * sigma_up)

    return x

@torch.no_grad()
def sample_euler_multipass(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_churn=0.,
    s_tmin=0.,
    s_tmax=float('inf'),
    s_noise=1.0,
    pass_steps=2,
    pass_sigma_max=float("inf"),
    pass_sigma_min=12.0,
):
    """
    A multipass variant of Euler sampling (Karras et al. 2022, Algorithm 2 style).

    Steps whose starting sigma lies in [pass_sigma_min, pass_sigma_max] are
    subdivided into `pass_steps` equally spaced sub-steps; other steps run as
    a single Euler step. With s_churn > 0, noise is temporarily injected
    ("churn") before a sub-step, raising the effective sigma to sigma_hat.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_i = sigmas[i]
        sigma_ip1 = sigmas[i + 1]

        # Decide how many sub-steps to do
        if pass_sigma_min <= sigma_i <= pass_sigma_max:
            n_sub = pass_steps
        else:
            n_sub = 1
        sub_sigmas = torch.linspace(sigma_i, sigma_ip1, n_sub + 1, device=sigmas.device)

        for sub_step in range(n_sub):
            # Current sub-step range:
            sub_sigma_curr = sub_sigmas[sub_step]
            sub_sigma_next = sub_sigmas[sub_step + 1]

            if s_churn > 0:
                # max(n_sub - 1, 1) guards against ZeroDivisionError when the
                # step is not subdivided (n_sub == 1).
                gamma = min(s_churn / max(n_sub - 1, 1), 2 ** 0.5 - 1) if s_tmin <= sub_sigma_curr < s_tmax else 0
                sigma_hat = sub_sigma_curr * (gamma + 1)
            else:
                gamma = 0
                sigma_hat = sub_sigma_curr

            if gamma > 0:
                # Churn: raise the noise level from sub_sigma_curr up to
                # sigma_hat. (The previous expression
                # (sigma_hat ** 2 - sigma_hat ** 2) ** 0.5 was identically
                # zero, so no churn noise was ever added.)
                eps = torch.randn_like(x) * s_noise
                x = x + eps * (sigma_hat ** 2 - sub_sigma_curr ** 2) ** 0.5

            # Denoise at sigma_hat: after churn, x is at noise level sigma_hat,
            # so the model must be conditioned on it (matches the derivative
            # and dt below).
            denoised = model(x, sigma_hat * s_in, **extra_args)

            if callback is not None:
                callback({
                    'x': x,
                    'i': i,
                    'sub_step': sub_step,
                    'sigma': sub_sigma_curr,
                    'sigma_hat': sigma_hat,
                    'denoised': denoised,
                })

            d = to_d(x, sigma_hat, denoised)
            dt = sub_sigma_next - sigma_hat
            # Euler method
            x = x + d * dt
    return x

@torch.no_grad()
def sample_euler_multipass_cfg_pp(
    model, x, sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    s_noise=1.0,
    s_churn=0.,
    s_tmin=0.,
    s_tmax=float('inf'),
    noise_sampler=None,
    pass_steps=2,
    pass_sigma_max=float("inf"),
    pass_sigma_min=12.0,
):
    """
    CFG++-enabled multipass Euler sampler.

    Steps whose starting sigma lies in [pass_sigma_min, pass_sigma_max] are
    subdivided into `pass_steps` equally spaced sub-steps; other steps run as
    a single Euler step. The CFG++ update computes the derivative from the
    unconditional prediction (captured via a post-CFG hook) and adds the
    (denoised - uncond_denoised) guidance delta directly.
    """
    if extra_args is None:
        extra_args = {}
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

    # CFG++ wrapper: capture the unconditional prediction on every model call.
    temp = [0]
    def post_cfg_function(args):
        temp[0] = args["uncond_denoised"]
        return args["denoised"]
    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(
        model_options, post_cfg_function, disable_cfg1_optimization=True
    )

    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_i = sigmas[i]
        sigma_ip1 = sigmas[i + 1]

        n_sub = pass_steps if pass_sigma_min <= sigma_i <= pass_sigma_max else 1
        sub_sigmas = torch.linspace(sigma_i, sigma_ip1, n_sub + 1, device=sigmas.device)

        for sub_step in range(n_sub):
            sub_sigma = sub_sigmas[sub_step]
            sub_sigma_next = sub_sigmas[sub_step + 1]

            if s_churn > 0:
                # max(n_sub - 1, 1) guards against ZeroDivisionError when the
                # step is not subdivided (n_sub == 1).
                gamma = min(s_churn / max(n_sub - 1, 1), 2 ** 0.5 - 1) if s_tmin <= sub_sigma < s_tmax else 0
                sigma_hat = sub_sigma * (gamma + 1)
            else:
                gamma = 0
                sigma_hat = sub_sigma

            if gamma > 0:
                # Churn: raise the noise level from sub_sigma up to sigma_hat.
                eps = torch.randn_like(x) * s_noise
                x = x + eps * (sigma_hat ** 2 - sub_sigma ** 2).sqrt()

            denoised = model(x, sigma_hat * s_in, **extra_args)
            # CFG++: derivative taken from the unconditional prediction.
            d = to_d(x, sigma_hat, temp[0])
            dt = sub_sigma_next - sigma_hat

            x = x + (denoised - temp[0]) + d * dt

            if callback is not None:
                callback({
                    'x': x, 'i': i, 'sub_step': sub_step,
                    'sigma': sub_sigma, 'sigma_hat': sigma_hat,
                    'denoised': denoised, 'uncond_denoised': temp[0]
                })

    return x

@torch.no_grad()
def sample_euler_multipass_ancestral_cfg_pp(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    pass_steps=2,
    pass_sigma_max=float("inf"),
    pass_sigma_min=12.0,
):
    """
    CFG++-enabled multipass ancestral Euler sampler.

    Steps whose starting sigma lies in [pass_sigma_min, pass_sigma_max] are
    subdivided into `pass_steps` equally spaced sub-steps; other steps run as
    a single ancestral Euler step. The CFG++ update computes the derivative
    from the unconditional prediction (captured via a post-CFG hook) and adds
    the (denoised - uncond_denoised) guidance delta directly, then re-injects
    noise of size sigma_up as in ordinary ancestral sampling.

    Note: this mutates extra_args["model_options"] to install the post-CFG
    hook (a copy of the caller's model_options dict is patched, not the
    original dict object).
    """
    if extra_args is None:
        extra_args = {}
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

    # CFG++ wrapper: temp[0] holds the unconditional prediction from the most
    # recent model call (a one-element list so the closure can write to it).
    temp = [0]
    def post_cfg_function(args):
        temp[0] = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(
        model_options, post_cfg_function, disable_cfg1_optimization=True
    )

    s_in = x.new_ones([x.shape[0]])

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_i = sigmas[i]
        sigma_ip1 = sigmas[i + 1]

        # Subdivision: pass_steps sub-steps inside the configured sigma
        # window, otherwise a single step.
        n_sub = pass_steps if pass_sigma_min <= sigma_i <= pass_sigma_max else 1
        sub_sigmas = torch.linspace(sigma_i, sigma_ip1, n_sub + 1, device=sigmas.device)

        for sub_step in range(n_sub):
            sub_sigma = sub_sigmas[sub_step]
            sub_sigma_next = sub_sigmas[sub_step + 1]

            # Compute ancestral steps: deterministic target sigma_down plus
            # stochastic renoising magnitude sigma_up (scaled by eta).
            sigma_down, sigma_up = get_ancestral_step(sub_sigma, sub_sigma_next, eta=eta)

            # CFG++ denoise: the model call also updates temp[0] via the hook.
            denoised = model(x, sub_sigma * s_in, **extra_args)
            # Derivative from the unconditional prediction (CFG++).
            d = to_d(x, sub_sigma, temp[0])
            dt = sigma_down - sub_sigma

            # Main ancestral Euler update with CFG++ guidance delta.
            x = x + (denoised - temp[0]) + d * dt

            # Noise injection (skipped on the final step down to sigma 0).
            if sub_sigma_next > 0:
                x = x + noise_sampler(sub_sigma, sub_sigma_next) * s_noise * sigma_up

            if callback is not None:
                # 'sigma_hat' mirrors the churn-enabled samplers' callback
                # schema; no churn here, so it equals sub_sigma.
                callback({
                    'x': x, 'i': i, 'sub_step': sub_step,
                    'sigma': sub_sigma, 'sigma_hat': sub_sigma,
                    'denoised': denoised, 'uncond_denoised': temp[0]
                })

    return x

@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
Expand Down
4 changes: 3 additions & 1 deletion comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,9 @@ def max_denoise(self, model_wrap, sigmas):
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma

KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp",
"euler_multipass", "euler_multipass_cfg_pp", "euler_ancestral_multipass", "euler_multipass_ancestral_cfg_pp",
"heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
Expand Down
Loading