@@ -1154,3 +1154,36 @@ def post_cfg_function(args):
1154
1154
if sigmas [i + 1 ] > 0 :
1155
1155
x = x + noise_sampler (sigmas [i ], sigmas [i + 1 ]) * s_noise * sigma_up
1156
1156
return x
1157
+
1158
+ @torch .no_grad ()
1159
+ def sample_dpmpp_2m_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None ):
1160
+ """DPM-Solver++(2M)."""
1161
+ extra_args = {} if extra_args is None else extra_args
1162
+ s_in = x .new_ones ([x .shape [0 ]])
1163
+ t_fn = lambda sigma : sigma .log ().neg ()
1164
+
1165
+ old_uncond_denoised = None
1166
+ uncond_denoised = None
1167
+ def post_cfg_function (args ):
1168
+ nonlocal uncond_denoised
1169
+ uncond_denoised = args ["uncond_denoised" ]
1170
+ return args ["denoised" ]
1171
+
1172
+ model_options = extra_args .get ("model_options" , {}).copy ()
1173
+ extra_args ["model_options" ] = comfy .model_patcher .set_model_options_post_cfg_function (model_options , post_cfg_function , disable_cfg1_optimization = True )
1174
+
1175
+ for i in trange (len (sigmas ) - 1 , disable = disable ):
1176
+ denoised = model (x , sigmas [i ] * s_in , ** extra_args )
1177
+ if callback is not None :
1178
+ callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
1179
+ t , t_next = t_fn (sigmas [i ]), t_fn (sigmas [i + 1 ])
1180
+ h = t_next - t
1181
+ if old_uncond_denoised is None or sigmas [i + 1 ] == 0 :
1182
+ denoised_mix = - torch .exp (- h ) * uncond_denoised
1183
+ else :
1184
+ h_last = t - t_fn (sigmas [i - 1 ])
1185
+ r = h_last / h
1186
+ denoised_mix = - torch .exp (- h ) * uncond_denoised - torch .expm1 (- h ) * (1 / (2 * r )) * (denoised - old_uncond_denoised )
1187
+ x = denoised + denoised_mix + torch .exp (- h ) * x
1188
+ old_uncond_denoised = uncond_denoised
1189
+ return x
0 commit comments