Skip to content

Commit 479a427

Browse files
authored
Add dpmpp_2m_cfg_pp (#4992)
1 parent 3a0eeee commit 479a427

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,3 +1154,36 @@ def post_cfg_function(args):
11541154
if sigmas[i + 1] > 0:
11551155
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
11561156
return x
1157+
1158+
@torch.no_grad()
1159+
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
1160+
"""DPM-Solver++(2M)."""
1161+
extra_args = {} if extra_args is None else extra_args
1162+
s_in = x.new_ones([x.shape[0]])
1163+
t_fn = lambda sigma: sigma.log().neg()
1164+
1165+
old_uncond_denoised = None
1166+
uncond_denoised = None
1167+
def post_cfg_function(args):
1168+
nonlocal uncond_denoised
1169+
uncond_denoised = args["uncond_denoised"]
1170+
return args["denoised"]
1171+
1172+
model_options = extra_args.get("model_options", {}).copy()
1173+
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
1174+
1175+
for i in trange(len(sigmas) - 1, disable=disable):
1176+
denoised = model(x, sigmas[i] * s_in, **extra_args)
1177+
if callback is not None:
1178+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1179+
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
1180+
h = t_next - t
1181+
if old_uncond_denoised is None or sigmas[i + 1] == 0:
1182+
denoised_mix = -torch.exp(-h) * uncond_denoised
1183+
else:
1184+
h_last = t - t_fn(sigmas[i - 1])
1185+
r = h_last / h
1186+
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
1187+
x = denoised + denoised_mix + torch.exp(-h) * x
1188+
old_uncond_denoised = uncond_denoised
1189+
return x

comfy/samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def max_denoise(self, model_wrap, sigmas):
571571

572572
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
573573
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
574-
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
574+
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
575575
"ipndm", "ipndm_v", "deis"]
576576

577577
class KSAMPLER(Sampler):

0 commit comments

Comments
 (0)