Skip to content

Commit 42da274

Browse files
authored
Use normal ComfyUI attention in ACE-Steps model (Comfy-Org#8023)
* Use normal ComfyUI attention in ACE-Steps model
* Let optimized_attention handle output reshape for ACE
1 parent 28f178a commit 42da274

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

comfy/ldm/ace/attention.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch import nn
2020

2121
import comfy.model_management
22+
from comfy.ldm.modules.attention import optimized_attention
2223

2324
class Attention(nn.Module):
2425
def __init__(
@@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
326327
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
327328
"""
328329

329-
def __init__(self):
330-
if not hasattr(F, "scaled_dot_product_attention"):
331-
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
332-
333330
def apply_rotary_emb(
334331
self,
335332
x: torch.Tensor,
@@ -435,13 +432,9 @@ def __call__(
435432
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
436433

437434
# the output of sdp = (batch, num_heads, seq_len, head_dim)
438-
# TODO: add support for attn.scale when we move to Torch 2.1
439-
hidden_states = F.scaled_dot_product_attention(
440-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
441-
)
442-
443-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
444-
hidden_states = hidden_states.to(query.dtype)
435+
hidden_states = optimized_attention(
436+
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
437+
).to(query.dtype)
445438

446439
# linear proj
447440
hidden_states = attn.to_out[0](hidden_states)

0 commit comments

Comments (0)