Merged
12 changes: 8 additions & 4 deletions examples/flash_attention/example_mha_fwd_bhsd.py
@@ -34,6 +34,9 @@ def flashattn(batch,
dtype = "float16"
accum_dtype = "float"

past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
Comment on lines +37 to +38
Contributor

medium

While the logic here is correct, I've noticed that this change, and indeed most of the flashattn function and the ref_program, is duplicated in example_mha_fwd_bhsd_wgmma_pipelined.py. This duplication increases the maintenance burden, as any future changes will need to be applied in both places.

To improve maintainability, consider refactoring the common logic into a shared module. These example files could then import the common components and only define the parts that are specific to them (like the T.Pipelined loop configuration and main function arguments).
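To make that concrete, here is a minimal sketch of such a shared module; the module name mha_fwd_common.py is hypothetical, and the einsum/scaling details are assumptions based on the fragments visible in this diff, not the exact code:

# mha_fwd_common.py - hypothetical shared module for both examples
import torch
import torch.nn.functional as F


def ref_program(Q, K, V, is_causal):
    """Reference attention in BHSD layout, shared by both example files."""
    dim = Q.size(-1)
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) / dim**0.5
    if is_causal:
        seq_q, seq_kv = Q.size(2), K.size(2)
        # Same diagonal offset as the fix in this PR: keys may extend
        # seq_kv - seq_q positions before the first query.
        mask = torch.tril(
            torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        scores = scores.masked_fill(mask[None, None] == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    return torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)

Each example file would then import ref_program (and any other shared helpers) and keep only its variant-specific kernel configuration.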


Comment on lines +37 to +39
Contributor

⚠️ Potential issue

Gate the length invariant to causal mode (don’t block valid non‑causal cases).

Unconditionally asserting seq_kv >= seq_q forbids legitimate non‑causal runs where seq_q > seq_kv (e.g., cross‑attention). Restrict the assertion to is_causal.

-    past_len = seq_kv - seq_q
-    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+    past_len = seq_kv - seq_q
+    if is_causal:
+        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
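As a quick standalone illustration (the sizes here are made up), the gated version lets a legitimate cross-attention shape through while still protecting the causal path:

# Hypothetical shapes: more queries than keys, as in cross-attention.
seq_q, seq_kv, is_causal = 256, 128, False

past_len = seq_kv - seq_q  # -128; harmless because the causal mask is off
if is_causal:
    assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
# Execution continues: past_len is only consulted on the causal path.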

     @T.macro
     def MMA0(
             K: T.Tensor(kv_shape, dtype),
@@ -45,7 +48,6 @@ def MMA0(
             by: T.int32,
             bz: T.int32,
     ):
-        past_len = seq_kv - seq_q
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
@@ -135,8 +137,10 @@ def main(
             T.fill(scores_max, -T.infinity(accum_dtype))

             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(
+                    T.ceildiv(seq_kv, block_N), T.ceildiv(
+                        (bx + 1) * block_M +
+                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))

             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
@@ -159,7 +163,7 @@ def ref_program(Q, K, V, is_causal):
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
+        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
         scores = scores.masked_fill(mask == 0, float('-inf'))
         attention_weights = F.softmax(scores, dim=-1)
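A note on the fix itself: with seq_kv > seq_q, the queries correspond to the last seq_q positions of the key sequence, so query row i may attend to the first past_len + i + 1 keys. torch.tril's default diagonal=0 would cut the mask off past_len columns too early, and the causal loop bound has to extend by the same past_len. A small self-contained check, with purely illustrative sizes:

import torch

seq_q, seq_kv = 4, 6
past_len = seq_kv - seq_q  # 2 "past" key positions before the first query

before = torch.tril(torch.ones(seq_q, seq_kv))           # diagonal=0
after = torch.tril(torch.ones(seq_q, seq_kv), past_len)  # shifted right

print(before[0])  # tensor([1., 0., 0., 0., 0., 0.]) - row 0 sees 1 key
print(after[0])   # tensor([1., 1., 1., 0., 0., 0.]) - row 0 sees past_len + 1

# The loop-range change follows the same arithmetic: the last key column
# visible to the query block ending at (bx + 1) * block_M shifts by
# past_len, hence ceildiv((bx + 1) * block_M + past_len, block_N).
block_M, block_N, bx = 2, 2, 1
tiles = -(-((bx + 1) * block_M + past_len) // block_N)  # ceildiv -> 3
print(tiles)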
12 changes: 8 additions & 4 deletions examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
@@ -34,6 +34,9 @@ def flashattn(batch,
dtype = "float16"
accum_dtype = "float"

past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"

Comment on lines +37 to +39
Contributor

⚠️ Potential issue

Restrict the seq_kv >= seq_q precondition to causal mode.

Same concern as the non‑pipelined variant: don’t prevent valid non‑causal seq_q > seq_kv cases.

-    past_len = seq_kv - seq_q
-    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+    past_len = seq_kv - seq_q
+    if is_causal:
+        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"

     @T.macro
     def MMA0(
             K: T.Tensor(kv_shape, dtype),
@@ -45,7 +48,6 @@ def MMA0(
             by: T.int32,
             bz: T.int32,
     ):
-        past_len = seq_kv - seq_q
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
@@ -135,8 +137,10 @@ def main(
             T.fill(scores_max, -T.infinity(accum_dtype))

             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(
+                    T.ceildiv(seq_kv, block_N), T.ceildiv(
+                        (bx + 1) * block_M +
+                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))

             for k in T.Pipelined(
                     loop_range,
@@ -164,7 +168,7 @@ def ref_program(Q, K, V, is_causal):
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
+        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
         scores = scores.masked_fill(mask == 0, float('-inf'))
         attention_weights = F.softmax(scores, dim=-1)