Commit b74ba8d

akaitsuki-ii and qufei.qf authored
no fa3 with attention mask (#163)
* no fa3 with attention mask
* modify log level for fla attn_mask

Co-authored-by: qufei.qf <qufei.qf@alibaba-inc.com>
1 parent 665f74a commit b74ba8d

File tree

1 file changed: +7 −4 lines changed

diffsynth_engine/models/basic/attention.py

Lines changed: 7 additions & 4 deletions
@@ -135,12 +135,13 @@ def attention(
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:
-            if flash_attn3_compatible:
+            if flash_attn3_compatible and attn_mask is None:
                 return flash_attn3(q, k, v, softmax_scale=scale)
             else:
-                logger.warning(
-                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
-                )
+                if not flash_attn3_compatible:
+                    logger.warning(f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation")
+                else:
+                    logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:
@@ -156,6 +157,8 @@ def attention(
             raise RuntimeError(
                 f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
            )
+        if attn_mask is not None:
+            raise RuntimeError("flash_attn_3 does not support attention mask")
         return flash_attn3(q, k, v, softmax_scale=scale)
     if attn_impl == "flash_attn_2":
         return flash_attn2(q, k, v, softmax_scale=scale)
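
The effect on dispatch is easy to miss in the diff, so here is a minimal, self-contained sketch of the rule this commit enforces: in auto mode flash_attn_3 is only chosen when the head dimension fits and no attention mask is passed, and an explicit `attn_impl="flash_attn_3"` request with a mask now raises instead of silently dropping the mask. This is an illustration, not the library's actual code; the `FA3_MAX_HEADDIM` value, the `flash_attn3` stub, and the use of PyTorch SDPA as the fallback are assumptions made for the sketch.

```python
# Minimal sketch (assumed, not diffsynth_engine's actual code) of the dispatch
# rule in this commit: flash_attn_3 is only picked automatically when the head
# dim fits AND no attention mask is given; an explicit "flash_attn_3" request
# with a mask fails fast. FA3_MAX_HEADDIM, the flash_attn3 stub, and the
# PyTorch-SDPA fallback below are illustration-only assumptions.
import logging
from typing import Optional

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

FA3_MAX_HEADDIM = 256           # assumed limit for this sketch
FLASH_ATTN_3_AVAILABLE = False  # pretend FA3 is not installed here


def flash_attn3(q, k, v, softmax_scale=None):
    raise RuntimeError("flash_attn_3 is not installed in this sketch")


def attention_sketch(
    q: torch.Tensor,  # assumed layout: (batch, heads, seq, head_dim)
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    attn_impl: Optional[str] = None,
) -> torch.Tensor:
    flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
    if attn_impl is None or attn_impl == "auto":
        if FLASH_ATTN_3_AVAILABLE:
            if flash_attn3_compatible and attn_mask is None:
                return flash_attn3(q, k, v, softmax_scale=scale)
            if not flash_attn3_compatible:
                logger.warning("head_dim too large for flash_attn_3, using fallback")
            else:
                logger.debug("flash_attn_3 does not support attention mask, using fallback")
        # Fallback: PyTorch SDPA accepts an attention mask directly.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
    if attn_impl == "flash_attn_3":
        if not flash_attn3_compatible:
            raise RuntimeError(f"head_dim={q.shape[-1]} exceeds {FA3_MAX_HEADDIM}")
        if attn_mask is not None:
            raise RuntimeError("flash_attn_3 does not support attention mask")
        return flash_attn3(q, k, v, softmax_scale=scale)
    raise ValueError(f"unknown attn_impl: {attn_impl}")


if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 16, 64)
    mask = torch.ones(1, 8, 16, 16, dtype=torch.bool)
    out = attention_sketch(q, k, v, attn_mask=mask)  # takes the SDPA fallback path
    print(out.shape)  # torch.Size([1, 8, 16, 64])
```

With a mask supplied, the sketch lands on the fallback path even when flash_attn_3 would otherwise be eligible, which mirrors the new `attn_mask is None` guard; logging the mask case at debug rather than warning matches the commit's "modify log level for fla attn_mask" note.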
