diffsynth_engine/models/basic/attention.py (86 changes: 53 additions & 33 deletions)
@@ -14,6 +14,8 @@
SPARGE_ATTN_AVAILABLE,
)

FA3_MAX_HEADDIM = 256

logger = logging.get_logger(__name__)


@@ -130,31 +132,40 @@ def attention(
"sage_attn",
"sparge_attn",
]
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
if attn_impl is None or attn_impl == "auto":
if FLASH_ATTN_3_AVAILABLE:
return flash_attn3(q, k, v, softmax_scale=scale)
elif XFORMERS_AVAILABLE:
if flash_attn3_compatible:
return flash_attn3(q, k, v, softmax_scale=scale)
else:
logger.warning(
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
)
if XFORMERS_AVAILABLE:
return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
elif SDPA_AVAILABLE:
if SDPA_AVAILABLE:
return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
elif FLASH_ATTN_2_AVAILABLE:
if FLASH_ATTN_2_AVAILABLE:
return flash_attn2(q, k, v, softmax_scale=scale)
else:
return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
else:
if attn_impl == "eager":
return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
elif attn_impl == "flash_attn_3":
if attn_impl == "flash_attn_3":
if not flash_attn3_compatible:
raise RuntimeError(
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
)
return flash_attn3(q, k, v, softmax_scale=scale)
elif attn_impl == "flash_attn_2":
if attn_impl == "flash_attn_2":
return flash_attn2(q, k, v, softmax_scale=scale)
elif attn_impl == "xformers":
if attn_impl == "xformers":
return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
elif attn_impl == "sdpa":
if attn_impl == "sdpa":
return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
elif attn_impl == "sage_attn":
if attn_impl == "sage_attn":
return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
elif attn_impl == "sparge_attn":
if attn_impl == "sparge_attn":
return sparge_attn(
q,
k,
@@ -166,8 +177,7 @@
cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98),
pvthreshd=kwargs.get("sparge_pvthreshd", 50),
)
else:
raise ValueError(f"Invalid attention implementation: {attn_impl}")
raise ValueError(f"Invalid attention implementation: {attn_impl}")


class Attention(nn.Module):
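For context, here is a minimal usage sketch of the new dispatch behavior. The import path follows the file shown in this diff, while the call signature and tensor shapes are assumptions for illustration rather than something stated in the PR:

import torch
from diffsynth_engine.models.basic.attention import attention  # module path assumed from the diff

# Hypothetical tensors; only the last dimension (head_dim) matters for the FA3 guard.
q = torch.randn(1, 1024, 8, 320, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 1024, 8, 320, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 1024, 8, 320, dtype=torch.bfloat16, device="cuda")

# head_dim=320 exceeds FA3_MAX_HEADDIM (256): with "auto", the new code logs a
# warning and falls through to xformers, SDPA, flash-attn 2, or eager attention.
out = attention(q, k, v, attn_impl="auto")

# Explicitly requesting flash_attn_3 with an oversized head_dim now raises a
# RuntimeError up front instead of failing inside the kernel.
try:
    attention(q, k, v, attn_impl="flash_attn_3")
except RuntimeError as err:
    print(err)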
@@ -240,32 +250,42 @@ def long_context_attention(
"sage_attn",
"sparge_attn",
]
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
if attn_impl is None or attn_impl == "auto":
if FLASH_ATTN_3_AVAILABLE:
attn_func = LongContextAttention(attn_type=AttnType.FA3)
elif SDPA_AVAILABLE:
attn_func = LongContextAttention(attn_type=AttnType.TORCH)
elif FLASH_ATTN_2_AVAILABLE:
attn_func = LongContextAttention(attn_type=AttnType.FA)
else:
raise ValueError("No available long context attention implementation")
if flash_attn3_compatible:
return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
else:
logger.warning(
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
)
if SDPA_AVAILABLE:
return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
if FLASH_ATTN_2_AVAILABLE:
return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
raise ValueError("No available long context attention implementation")
else:
if attn_impl == "flash_attn_3":
attn_func = LongContextAttention(attn_type=AttnType.FA3)
elif attn_impl == "flash_attn_2":
attn_func = LongContextAttention(attn_type=AttnType.FA)
elif attn_impl == "sdpa":
attn_func = LongContextAttention(attn_type=AttnType.TORCH)
elif attn_impl == "sage_attn":
attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
elif attn_impl == "sparge_attn":
if flash_attn3_compatible:
return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
else:
raise RuntimeError(
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
)
if attn_impl == "flash_attn_2":
return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
if attn_impl == "sdpa":
return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
if attn_impl == "sage_attn":
return LongContextAttention(attn_type=AttnType.SAGE_FP8)(q, k, v, softmax_scale=scale)
if attn_impl == "sparge_attn":
attn_processor = SparseAttentionMeansim()
# default args from spas_sage2_attn_meansim_cuda
attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True))
attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6))
attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98))
attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50))
attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)
else:
raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
return attn_func(q, k, v, softmax_scale=scale)
return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
q, k, v, softmax_scale=scale
)
raise ValueError(f"Invalid long context attention implementation: {attn_impl}")