Commit 1a85483

Fix depending on asserts to raise an exception in BatchedBrownianTree and Flash attn module (#9884)
Correctly handle the case where w0 is passed by kwargs in BatchedBrownianTree
1 parent 47a9cde · commit 1a85483
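
Why relying on asserts is the bug here: CPython strips `assert` statements when the interpreter runs with `python -O`, so any control flow that depends on an assert raising stops working silently. A minimal standalone sketch of the old seed-classification pattern (simplified, not the project's code verbatim):

    def classify_seed(seed, batch_size):
        batched = True
        try:
            assert len(seed) == batch_size  # an int seed raises TypeError here
        except TypeError:
            seed = [seed]                   # wrap the scalar seed
            batched = False
        return seed, batched

    print(classify_seed(42, 4))  # normally: ([42], False)

Under `python -O` the assert is removed, `len(42)` is never evaluated, and the call returns `(42, True)`: a scalar seed mistaken for a batch. A seed list of the wrong length is likewise accepted without complaint. The rewrite below replaces this with an explicit `isinstance` check and a `ValueError` that survives `-O`.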

2 files changed, +19 −19 lines

comfy/k_diffusion/sampling.py

Lines changed: 17 additions & 18 deletions
@@ -86,36 +86,35 @@ class BatchedBrownianTree:
     """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
 
     def __init__(self, x, t0, t1, seed=None, **kwargs):
-        self.cpu_tree = True
-        if "cpu" in kwargs:
-            self.cpu_tree = kwargs.pop("cpu")
+        self.cpu_tree = kwargs.pop("cpu", True)
         t0, t1, self.sign = self.sort(t0, t1)
-        w0 = kwargs.get('w0', torch.zeros_like(x))
+        w0 = kwargs.pop('w0', None)
+        if w0 is None:
+            w0 = torch.zeros_like(x)
+        self.batched = False
         if seed is None:
-            seed = torch.randint(0, 2 ** 63 - 1, []).item()
-        self.batched = True
-        try:
-            assert len(seed) == x.shape[0]
+            seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
+        elif isinstance(seed, (tuple, list)):
+            if len(seed) != x.shape[0]:
+                raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
+            self.batched = True
             w0 = w0[0]
-        except TypeError:
-            seed = [seed]
-            self.batched = False
-        if self.cpu_tree:
-            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
         else:
-            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+            seed = (seed,)
+        if self.cpu_tree:
+            t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
+        self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
 
     @staticmethod
     def sort(a, b):
         return (a, b, 1) if a < b else (b, a, -1)
 
     def __call__(self, t0, t1):
         t0, t1, sign = self.sort(t0, t1)
+        device, dtype = t0.device, t0.dtype
         if self.cpu_tree:
-            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
-        else:
-            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
-
+            t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
+        w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
         return w if self.batched else w[0]
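
How the rewritten `__init__` routes the three seed cases, as a standalone usage sketch (the function name `normalize_seed` is hypothetical; the branching mirrors the diff above):

    import torch

    def normalize_seed(seed, batch_size):
        batched = False
        if seed is None:
            seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)  # one random tree
        elif isinstance(seed, (tuple, list)):
            if len(seed) != batch_size:
                raise ValueError("seed list/tuple length must match the batch size")
            batched = True                                      # one tree per batch element
        else:
            seed = (seed,)                                      # scalar seed, single tree
        return seed, batched

    normalize_seed(None, 2)    # ((<random int>,), False)
    normalize_seed([1, 2], 2)  # ([1, 2], True)
    normalize_seed(7, 2)       # ((7,), False)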

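The `w0` part of the fix in isolation: `kwargs.get('w0', ...)` read the value but left `'w0'` inside `kwargs`, so the subsequent `torchsde.BrownianTree(t0, w0, t1, **kwargs)` call received `w0` both positionally and as a keyword. A minimal sketch with a hypothetical stand-in for the constructor:

    import torch

    def make_tree(t0, w0, t1, **kwargs):
        # Hypothetical stand-in for torchsde.BrownianTree.
        return (t0, w0, t1, kwargs)

    x = torch.zeros(4)

    # Old behavior: get() leaves 'w0' in kwargs, so it is passed twice.
    kwargs = {"w0": torch.ones_like(x)}
    w0 = kwargs.get("w0", torch.zeros_like(x))
    try:
        make_tree(0.0, w0, 1.0, **kwargs)
    except TypeError as e:
        print(e)  # got multiple values for argument 'w0'

    # New behavior: pop() consumes 'w0', so it is passed only once.
    kwargs = {"w0": torch.ones_like(x)}
    w0 = kwargs.pop("w0", None)
    if w0 is None:
        w0 = torch.zeros_like(x)
    make_tree(0.0, w0, 1.0, **kwargs)  # OK
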
comfy/ldm/modules/attention.py

Lines changed: 2 additions & 1 deletion
@@ -600,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         mask = mask.unsqueeze(1)
 
     try:
-        assert mask is None
+        if mask is not None:
+            raise RuntimeError("Mask must not be set for Flash attention")
         out = flash_attn_wrapper(
             q.transpose(1, 2),
             k.transpose(1, 2),
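
For the Flash attention change: the `assert mask is None` sits inside a `try` block whose `except` branch (not shown in this diff) presumably falls back to another attention implementation, so under `python -O` the stripped assert let the flash path run and silently ignore the mask instead of falling back. A standalone sketch of the pattern, with hypothetical `flash_path`/`fallback_path` stand-ins:

    import torch
    import torch.nn.functional as F

    def flash_path(q, k, v):
        # Stand-in for flash_attn_wrapper: the fast kernel takes no mask here.
        return F.scaled_dot_product_attention(q, k, v)

    def fallback_path(q, k, v, mask):
        # Stand-in for a slower implementation that honors the mask.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    def attend(q, k, v, mask=None):
        try:
            if mask is not None:  # replaces `assert mask is None`
                raise RuntimeError("Mask must not be set for Flash attention")
            return flash_path(q, k, v)
        except Exception:
            return fallback_path(q, k, v, mask)

    q = k = v = torch.randn(1, 2, 4, 8)
    mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
    attend(q, k, v)        # takes the flash path
    attend(q, k, v, mask)  # raises inside try, falls back to the masked path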
