@@ -78,7 +78,75 @@ def post_cfg_function(args):
7878
7979 return (m , )
8080
81+ class SkipLayerGuidanceDiTSimple :
82+ '''
83+ Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
84+ '''
85+ @classmethod
86+ def INPUT_TYPES (s ):
87+ return {"required" : {"model" : ("MODEL" , ),
88+ "double_layers" : ("STRING" , {"default" : "7, 8, 9" , "multiline" : False }),
89+ "single_layers" : ("STRING" , {"default" : "7, 8, 9" , "multiline" : False }),
90+ "start_percent" : ("FLOAT" , {"default" : 0.0 , "min" : 0.0 , "max" : 1.0 , "step" : 0.001 }),
91+ "end_percent" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 1.0 , "step" : 0.001 }),
92+ }}
93+ RETURN_TYPES = ("MODEL" ,)
94+ FUNCTION = "skip_guidance"
95+ EXPERIMENTAL = True
96+
97+ DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
98+
99+ CATEGORY = "advanced/guidance"
100+
101+ def skip_guidance (self , model , start_percent , end_percent , double_layers = "" , single_layers = "" ):
102+ def skip (args , extra_args ):
103+ return args
104+
105+ model_sampling = model .get_model_object ("model_sampling" )
106+ sigma_start = model_sampling .percent_to_sigma (start_percent )
107+ sigma_end = model_sampling .percent_to_sigma (end_percent )
108+
109+ double_layers = re .findall (r'\d+' , double_layers )
110+ double_layers = [int (i ) for i in double_layers ]
111+
112+ single_layers = re .findall (r'\d+' , single_layers )
113+ single_layers = [int (i ) for i in single_layers ]
114+
115+ if len (double_layers ) == 0 and len (single_layers ) == 0 :
116+ return (model , )
117+
118+ def calc_cond_batch_function (args ):
119+ x = args ["input" ]
120+ model = args ["model" ]
121+ conds = args ["conds" ]
122+ sigma = args ["sigma" ]
123+
124+ model_options = args ["model_options" ]
125+ slg_model_options = model_options .copy ()
126+
127+ for layer in double_layers :
128+ slg_model_options = comfy .model_patcher .set_model_options_patch_replace (slg_model_options , skip , "dit" , "double_block" , layer )
129+
130+ for layer in single_layers :
131+ slg_model_options = comfy .model_patcher .set_model_options_patch_replace (slg_model_options , skip , "dit" , "single_block" , layer )
132+
133+ cond , uncond = conds
134+ sigma_ = sigma [0 ].item ()
135+ if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None :
136+ cond_out , _ = comfy .samplers .calc_cond_batch (model , [cond , None ], x , sigma , model_options )
137+ _ , uncond_out = comfy .samplers .calc_cond_batch (model , [None , uncond ], x , sigma , slg_model_options )
138+ out = [cond_out , uncond_out ]
139+ else :
140+ out = comfy .samplers .calc_cond_batch (model , conds , x , sigma , model_options )
141+
142+ return out
143+
144+ m = model .clone ()
145+ m .set_model_sampler_calc_cond_batch_function (calc_cond_batch_function )
146+
147+ return (m , )
81148
82149NODE_CLASS_MAPPINGS = {
83150 "SkipLayerGuidanceDiT" : SkipLayerGuidanceDiT ,
151+ "SkipLayerGuidanceDiTSimple" : SkipLayerGuidanceDiTSimple ,
84152}
0 commit comments