Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions comfy/ldm/ace/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch import nn

import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention

class Attention(nn.Module):
def __init__(
Expand Down Expand Up @@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def apply_rotary_emb(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -435,13 +432,9 @@ def __call__(
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
).to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
Expand Down