@@ -178,10 +178,7 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
178178 xc = torch .cat ([xc ] + [comfy .model_management .cast_to_device (c_concat , xc .device , xc .dtype )], dim = 1 )
179179
180180 context = c_crossattn
181- dtype = self .get_dtype ()
182-
183- if self .manual_cast_dtype is not None :
184- dtype = self .manual_cast_dtype
181+ dtype = self .get_dtype_inference ()
185182
186183 xc = xc .to (dtype )
187184 device = xc .device
@@ -218,6 +215,13 @@ def process_timestep(self, timestep, **kwargs):
def get_dtype(self):
    """Return the parameter dtype of the wrapped diffusion model."""
    model = self.diffusion_model
    return model.dtype
220217
def get_dtype_inference(self):
    """Dtype to use for inference.

    Returns ``manual_cast_dtype`` when it is set, otherwise the model's
    own dtype from ``get_dtype()`` (which is always queried first).
    """
    base = self.get_dtype()
    override = self.manual_cast_dtype
    return base if override is None else override
224+
def encode_adm(self, **kwargs):
    """Base hook for extra model conditioning; returns None (no conditioning).

    NOTE(review): subclasses presumably override this to build
    model-specific conditioning from **kwargs — the overrides are not
    visible in this chunk, so confirm against the full file.
    """
    return None
223227
@@ -372,9 +376,7 @@ def memory_required(self, input_shape, cond_shapes={}):
372376 input_shapes += shape
373377
374378 if comfy .model_management .xformers_enabled () or comfy .model_management .pytorch_attention_flash_attention ():
375- dtype = self .get_dtype ()
376- if self .manual_cast_dtype is not None :
377- dtype = self .manual_cast_dtype
379+ dtype = self .get_dtype_inference ()
378380 #TODO: this needs to be tweaked
379381 area = sum (map (lambda input_shape : input_shape [0 ] * math .prod (input_shape [2 :]), input_shapes ))
380382 return (area * comfy .model_management .dtype_size (dtype ) * 0.01 * self .memory_usage_factor ) * (1024 * 1024 )
@@ -1165,7 +1167,7 @@ def extra_conds(self, **kwargs):
11651167 t5xxl_ids = t5xxl_ids .unsqueeze (0 )
11661168
11671169 if torch .is_inference_mode_enabled (): # if not we are training
1168- cross_attn = self .diffusion_model .preprocess_text_embeds (cross_attn .to (device = device , dtype = self .get_dtype ()), t5xxl_ids .to (device = device ), t5xxl_weights = t5xxl_weights .to (device = device , dtype = self .get_dtype ()))
1170+ cross_attn = self .diffusion_model .preprocess_text_embeds (cross_attn .to (device = device , dtype = self .get_dtype_inference ()), t5xxl_ids .to (device = device ), t5xxl_weights = t5xxl_weights .to (device = device , dtype = self .get_dtype_inference ()))
11691171 else :
11701172 out ['t5xxl_ids' ] = comfy .conds .CONDRegular (t5xxl_ids )
11711173 out ['t5xxl_weights' ] = comfy .conds .CONDRegular (t5xxl_weights )
0 commit comments