Skip to content

Commit 84f9759

Browse files
Add some warnings and prevent crash when cond devices don't match. (Comfy-Org#9169)
1 parent 7991341 commit 84f9759

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

comfy/conds.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import math
33
import comfy.utils
4+
import logging
45

56

67
class CONDRegular:
@@ -16,6 +17,9 @@ def process_cond(self, batch_size, **kwargs):
1617
def can_concat(self, other):
1718
if self.cond.shape != other.cond.shape:
1819
return False
20+
if self.cond.device != other.cond.device:
21+
logging.warning("WARNING: conds not on same device, skipping concat.")
22+
return False
1923
return True
2024

2125
def concat(self, others):
@@ -51,6 +55,9 @@ def can_concat(self, other):
5155
diff = mult_min // min(s1[1], s2[1])
5256
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
5357
return False
58+
if self.cond.device != other.cond.device:
59+
logging.warning("WARNING: conds not on same device: skipping concat.")
60+
return False
5461
return True
5562

5663
def concat(self, others):

comfy/model_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor):
409409
if "unclip_conditioning" in args:
410410
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
411411
else:
412-
return args["pooled_output"].to(device=args["device"])
412+
return args["pooled_output"]
413413

414414
class SDXLRefiner(BaseModel):
415415
def __init__(self, model_config, model_type=ModelType.EPS, device=None):

0 commit comments

Comments
 (0)