11 changes: 7 additions & 4 deletions diffsynth_engine/models/basic/attention.py
@@ -135,12 +135,13 @@ def attention(
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
-            if flash_attn3_compatible:
+            if flash_attn3_compatible and attn_mask is None:
                 return flash_attn3(q, k, v, softmax_scale=scale)
             else:
-                logger.warning(
-                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
-                )
+                if not flash_attn3_compatible:
+                    logger.warning(f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation")
+                else:
+                    logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:
@@ -156,6 +157,8 @@ def attention(
             raise RuntimeError(
                 f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
             )
+        if attn_mask is not None:
+            raise RuntimeError("flash_attn_3 does not support attention mask")
         return flash_attn3(q, k, v, softmax_scale=scale)
     if attn_impl == "flash_attn_2":
         return flash_attn2(q, k, v, softmax_scale=scale)
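For context, here is a minimal standalone sketch of the dispatch rule this change establishes: in "auto" mode, flash_attn_3 is used only when the head dimension fits within FA3_MAX_HEADDIM and no attention mask is supplied; otherwise the call falls back to a mask-capable backend. Only the names that appear in the diff (attention, attn_mask, scale, FA3_MAX_HEADDIM, FLASH_ATTN_3_AVAILABLE) are taken from it; the FA3_MAX_HEADDIM value, the flash_attn_interface import, and the use of PyTorch SDPA as the fallback are illustrative assumptions, not the repository's exact code.

# Hypothetical, self-contained sketch of the auto-dispatch rule in this PR.
# Assumptions: the FA3_MAX_HEADDIM value, the flash_attn_interface import name,
# and the PyTorch SDPA fallback are illustrative, not the repo's exact code.
import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

FA3_MAX_HEADDIM = 256           # assumed cap; the real value lives in the repo
FLASH_ATTN_3_AVAILABLE = False  # set True only if the flash-attn 3 kernels are installed


def attention(q, k, v, attn_mask=None, scale=None):
    # flash_attn_3 is eligible only when the head dim fits and no mask is given,
    # which is exactly the condition this PR adds.
    if FLASH_ATTN_3_AVAILABLE and q.shape[-1] <= FA3_MAX_HEADDIM and attn_mask is None:
        from flash_attn_interface import flash_attn_func  # assumed import name
        return flash_attn_func(q, k, v, softmax_scale=scale)
    if FLASH_ATTN_3_AVAILABLE and attn_mask is not None:
        logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
    # Fallback: PyTorch SDPA accepts boolean or additive masks (scale= needs PyTorch >= 2.1).
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)


# Usage: SDPA expects (batch, heads, seq_len, head_dim) tensors.
q = k = v = torch.randn(1, 8, 128, 64)
mask = torch.ones(1, 1, 128, 128, dtype=torch.bool)
out = attention(q, k, v, attn_mask=mask)  # takes the fallback path because a mask is present

The second hunk covers the explicit case: when attn_impl="flash_attn_3" is requested directly, the change raises a RuntimeError for a masked call instead of silently falling back.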