|
19 | 19 | from torch import nn |
20 | 20 |
|
21 | 21 | import comfy.model_management |
| 22 | +from comfy.ldm.modules.attention import optimized_attention |
22 | 23 |
|
23 | 24 | class Attention(nn.Module): |
24 | 25 | def __init__( |
@@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0: |
326 | 327 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). |
327 | 328 | """ |
328 | 329 |
|
329 | | - def __init__(self): |
330 | | - if not hasattr(F, "scaled_dot_product_attention"): |
331 | | - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") |
332 | | - |
333 | 330 | def apply_rotary_emb( |
334 | 331 | self, |
335 | 332 | x: torch.Tensor, |
@@ -435,13 +432,9 @@ def __call__( |
435 | 432 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) |
436 | 433 |
|
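437 | 434 | # query/key/value are (batch, num_heads, seq_len, head_dim); NOTE(review): the attention result is passed straight to the linear projection below, so it is presumably already (batch, seq_len, num_heads * head_dim) — confirm |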
438 | | - # TODO: add support for attn.scale when we move to Torch 2.1 |
439 | | - hidden_states = F.scaled_dot_product_attention( |
440 | | - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False |
441 | | - ) |
442 | | - |
443 | | - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
444 | | - hidden_states = hidden_states.to(query.dtype) |
| 435 | + hidden_states = optimized_attention( |
| 436 | + query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, |
| 437 | + ).to(query.dtype) |
445 | 438 |
|
446 | 439 | # linear proj |
447 | 440 | hidden_states = attn.to_out[0](hidden_states) |
|
0 commit comments