Skip to content

Commit c396531

Browse files
Fix anima preprocess text embeds not using right inference dtype. (Comfy-Org#12501)
1 parent 1892753 commit c396531

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

comfy/model_base.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -178,10 +178,7 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
178 178
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
179 179

180 180
context = c_crossattn
181-
dtype = self.get_dtype()
182-
183-
if self.manual_cast_dtype is not None:
184-
dtype = self.manual_cast_dtype
181+
dtype = self.get_dtype_inference()
185 182

186 183
xc = xc.to(dtype)
187 184
device = xc.device
@@ -218,6 +215,13 @@ def process_timestep(self, timestep, **kwargs):
218 215
def get_dtype(self):
219 216
return self.diffusion_model.dtype
220 217

218+
def get_dtype_inference(self):
219+
dtype = self.get_dtype()
220+
221+
if self.manual_cast_dtype is not None:
222+
dtype = self.manual_cast_dtype
223+
return dtype
224+
221 225
def encode_adm(self, **kwargs):
222 226
return None
223 227

@@ -372,9 +376,7 @@ def memory_required(self, input_shape, cond_shapes={}):
372 376
input_shapes += shape
373 377

374 378
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
375-
dtype = self.get_dtype()
376-
if self.manual_cast_dtype is not None:
377-
dtype = self.manual_cast_dtype
379+
dtype = self.get_dtype_inference()
378 380
#TODO: this needs to be tweaked
379 381
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
380 382
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -1165,7 +1167,7 @@ def extra_conds(self, **kwargs):
1165 1167
t5xxl_ids = t5xxl_ids.unsqueeze(0)
1166 1168

1167 1169
if torch.is_inference_mode_enabled(): # if not we are training
1168-
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
1170+
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
1169 1171
else:
1170 1172
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
1171 1173
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)

0 commit comments

Comments
 (0)