Skip to content

Commit fe31ad0

Browse files
authored
Add elementwise fusions (Comfy-Org#9495)
* Add elementwise fusions
* Add addcmul pattern to Qwen
1 parent ca4e96a commit fe31ad0

File tree

3 files changed: +20 −18 lines changed

3 files changed: +20 −18 lines changed

comfy/ldm/modules/diffusionmodules/mmdit.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def forward(self, x):
109109
def modulate(x, shift, scale):
110110
if shift is None:
111111
shift = torch.zeros_like(scale)
112-
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
112+
return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1))
113113

114114

115115
#################################################################################
@@ -564,10 +564,7 @@ def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_
564564
assert not self.pre_only
565565
attn1 = self.attn.post_attention(attn)
566566
attn2 = self.attn2.post_attention(attn2)
567-
out1 = gate_msa.unsqueeze(1) * attn1
568-
out2 = gate_msa2.unsqueeze(1) * attn2
569-
x = x + out1
570-
x = x + out2
567+
x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
571568
x = x + gate_mlp.unsqueeze(1) * self.mlp(
572569
modulate(self.norm2(x), shift_mlp, scale_mlp)
573570
)
@@ -594,6 +591,11 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
594591
)
595592
return self.post_attention(attn, *intermediates)
596593

594+
def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
595+
out1 = gate_msa.unsqueeze(1) * attn1
596+
out2 = gate_msa2.unsqueeze(1) * attn2
597+
x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
598+
return x
597599

598600
def block_mixing(*args, use_checkpoint=True, **kwargs):
599601
if use_checkpoint:

comfy/ldm/qwen_image/model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,9 @@ def __init__(
214214
operations=operations,
215215
)
216216

217-
def _modulate(self, x, mod_params):
218-
shift, scale, gate = mod_params.chunk(3, dim=-1)
219-
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
217+
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
218+
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
219+
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
220220

221221
def forward(
222222
self,
@@ -248,11 +248,11 @@ def forward(
248248

249249
img_normed2 = self.img_norm2(hidden_states)
250250
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
251-
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
251+
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
252252

253253
txt_normed2 = self.txt_norm2(encoder_hidden_states)
254254
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
255-
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
255+
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
256256

257257
return encoder_hidden_states, hidden_states
258258

@@ -275,7 +275,7 @@ def __init__(
275275
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
276276
emb = self.linear(self.silu(conditioning_embedding))
277277
scale, shift = torch.chunk(emb, 2, dim=1)
278-
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
278+
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
279279
return x
280280

281281

comfy/ldm/wan/model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def forward(self, x, context, context_img_len):
148148

149149
def repeat_e(e, x):
150150
repeats = 1
151-
if e.shape[1] > 1:
152-
repeats = x.shape[1] // e.shape[1]
151+
if e.size(1) > 1:
152+
repeats = x.size(1) // e.size(1)
153153
if repeats == 1:
154154
return e
155155
return torch.repeat_interleave(e, repeats, dim=1)
@@ -219,15 +219,15 @@ def forward(
219219

220220
# self-attention
221221
y = self.self_attn(
222-
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
222+
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
223223
freqs)
224224

225-
x = x + y * repeat_e(e[2], x)
225+
x = torch.addcmul(x, y, repeat_e(e[2], x))
226226

227227
# cross-attention & ffn
228228
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
229-
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
230-
x = x + y * repeat_e(e[5], x)
229+
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
230+
x = torch.addcmul(x, y, repeat_e(e[5], x))
231231
return x
232232

233233

@@ -342,7 +342,7 @@ def forward(self, x, e):
342342
else:
343343
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
344344

345-
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
345+
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
346346
return x
347347

348348

Comments: 0 commit comments