Skip to content

Commit d927730

Browse files
authored
Initial code for new SLG node (Comfy-Org#8759)
1 parent 34c8eee commit d927730

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

comfy/model_patcher.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,9 @@ def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_op
379379
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
380380
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
381381

382+
def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
383+
self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
384+
382385
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
383386
self.model_options["model_function_wrapper"] = unet_wrapper_function
384387

comfy/samplers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,11 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
373373
uncond_ = uncond
374374

375375
conds = [cond, uncond_]
376-
out = calc_cond_batch(model, conds, x, timestep, model_options)
376+
if "sampler_calc_cond_batch_function" in model_options:
377+
args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
378+
out = model_options["sampler_calc_cond_batch_function"](args)
379+
else:
380+
out = calc_cond_batch(model, conds, x, timestep, model_options)
377381

378382
for fn in model_options.get("sampler_pre_cfg_function", []):
379383
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,

comfy_extras/nodes_slg.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,75 @@ def post_cfg_function(args):
7878

7979
return (m, )
8080

81+
class SkipLayerGuidanceDiTSimple:
82+
'''
83+
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
84+
'''
85+
@classmethod
86+
def INPUT_TYPES(s):
87+
return {"required": {"model": ("MODEL", ),
88+
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
89+
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
90+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
91+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
92+
}}
93+
RETURN_TYPES = ("MODEL",)
94+
FUNCTION = "skip_guidance"
95+
EXPERIMENTAL = True
96+
97+
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
98+
99+
CATEGORY = "advanced/guidance"
100+
101+
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
102+
def skip(args, extra_args):
103+
return args
104+
105+
model_sampling = model.get_model_object("model_sampling")
106+
sigma_start = model_sampling.percent_to_sigma(start_percent)
107+
sigma_end = model_sampling.percent_to_sigma(end_percent)
108+
109+
double_layers = re.findall(r'\d+', double_layers)
110+
double_layers = [int(i) for i in double_layers]
111+
112+
single_layers = re.findall(r'\d+', single_layers)
113+
single_layers = [int(i) for i in single_layers]
114+
115+
if len(double_layers) == 0 and len(single_layers) == 0:
116+
return (model, )
117+
118+
def calc_cond_batch_function(args):
119+
x = args["input"]
120+
model = args["model"]
121+
conds = args["conds"]
122+
sigma = args["sigma"]
123+
124+
model_options = args["model_options"]
125+
slg_model_options = model_options.copy()
126+
127+
for layer in double_layers:
128+
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "double_block", layer)
129+
130+
for layer in single_layers:
131+
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "single_block", layer)
132+
133+
cond, uncond = conds
134+
sigma_ = sigma[0].item()
135+
if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None:
136+
cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options)
137+
_, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options)
138+
out = [cond_out, uncond_out]
139+
else:
140+
out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
141+
142+
return out
143+
144+
m = model.clone()
145+
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
146+
147+
return (m, )
81148

82149
NODE_CLASS_MAPPINGS = {
83150
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
151+
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
84152
}

0 commit comments

Comments
 (0)