[Bugfix] Ensure correct handling for cases where seq_q < seq_kv in flash attention examples #864
```diff
@@ -34,6 +34,9 @@ def flashattn(batch,
     dtype = "float16"
     accum_dtype = "float"
 
+    past_len = seq_kv - seq_q
+    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+
```
| 
Comment on lines +37 to +39

Gate the length invariant to causal mode (don't block valid non-causal cases). Unconditionally asserting `past_len >= 0` rejects valid non-causal configurations where `seq_q > seq_kv`.

📝 Committable suggestion:

```diff
-    past_len = seq_kv - seq_q
-    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+    past_len = seq_kv - seq_q
+    if is_causal:
+        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
```
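For context on why the ungated assert is too strict, here is a minimal standalone PyTorch check (shapes are hypothetical, and `F.scaled_dot_product_attention` stands in for the example kernel): non-causal attention is well defined when `seq_q > seq_kv`.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: non-causal (e.g. cross-) attention with more queries
# than keys is a valid configuration that the ungated assert would reject.
Q = torch.randn(1, 2, 128, 64)   # (batch, heads, seq_q=128, dim)
K = torch.randn(1, 2, 32, 64)    # seq_kv = 32 < seq_q
V = torch.randn(1, 2, 32, 64)
out = F.scaled_dot_product_attention(Q, K, V)  # no causal mask applied
print(out.shape)                 # torch.Size([1, 2, 128, 64])
```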
```diff
     @T.macro
     def MMA0(
             K: T.Tensor(kv_shape, dtype),
@@ -45,7 +48,6 @@ def MMA0(
             by: T.int32,
             bz: T.int32,
     ):
-        past_len = seq_kv - seq_q
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
```
```diff
@@ -135,8 +137,10 @@ def main(
             T.fill(scores_max, -T.infinity(accum_dtype))
 
             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(
+                    T.ceildiv(seq_kv, block_N), T.ceildiv(
+                        (bx + 1) * block_M +
+                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
 
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
```
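To see what the extra `past_len` term in `loop_range` buys, here is a small standalone sketch (tile sizes are hypothetical) comparing the old and new kv-block counts per query block:

```python
import math

# With past_len = seq_kv - seq_q, causal query row i may attend kv positions
# j <= i + past_len, so query block bx must visit
# ceildiv((bx + 1) * block_M + past_len, block_N) kv blocks.
seq_q, seq_kv = 64, 256
block_M, block_N = 32, 32
past_len = seq_kv - seq_q  # 192

for bx in range(math.ceil(seq_q / block_M)):
    total = math.ceil(seq_kv / block_N)
    fixed = min(total, math.ceil(((bx + 1) * block_M + past_len) / block_N))
    old = min(total, math.ceil((bx + 1) * block_M / block_N))  # pre-fix bound
    print(f"bx={bx}: fixed={fixed} kv blocks, old={old} kv blocks")
# bx=0: fixed=7, old=1 -- the old bound skipped the 192 "past" kv positions
```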
```diff
@@ -159,7 +163,7 @@ def ref_program(Q, K, V, is_causal):
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
+        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
         scores = scores.masked_fill(mask == 0, float('-inf'))
     attention_weights = F.softmax(scores, dim=-1)
```
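The diagonal offset passed to `torch.tril` is the core of the reference fix; a tiny illustration with hypothetical sizes:

```python
import torch

seq_q, seq_kv = 3, 5  # past_len = 2
print(torch.tril(torch.ones(seq_q, seq_kv)))
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.]])   <- old mask: hides the 2 "past" kv positions
print(torch.tril(torch.ones(seq_q, seq_kv), seq_kv - seq_q))
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])   <- offset mask: row i sees columns 0..i + past_len
```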
The same fix in the pipelined example (`example_mha_fwd_bhsd_wgmma_pipelined.py`, per the review comment below):

```diff
@@ -34,6 +34,9 @@ def flashattn(batch,
     dtype = "float16"
     accum_dtype = "float"
 
+    past_len = seq_kv - seq_q
+    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+
```
| 
Comment on lines +37 to +39

Restrict the assertion to causal mode. Same concern as the non-pipelined variant: don't prevent valid non-causal shapes where `seq_q > seq_kv`.

📝 Committable suggestion:

```diff
-    past_len = seq_kv - seq_q
-    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+    past_len = seq_kv - seq_q
+    if is_causal:
+        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
```
```diff
     @T.macro
     def MMA0(
             K: T.Tensor(kv_shape, dtype),
@@ -45,7 +48,6 @@ def MMA0(
             by: T.int32,
             bz: T.int32,
     ):
-        past_len = seq_kv - seq_q
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
```
```diff
@@ -135,8 +137,10 @@ def main(
             T.fill(scores_max, -T.infinity(accum_dtype))
 
             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(
+                    T.ceildiv(seq_kv, block_N), T.ceildiv(
+                        (bx + 1) * block_M +
+                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
 
             for k in T.Pipelined(
                     loop_range,
```
```diff
@@ -164,7 +168,7 @@ def ref_program(Q, K, V, is_causal):
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
+        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
         scores = scores.masked_fill(mask == 0, float('-inf'))
     attention_weights = F.softmax(scores, dim=-1)
```
While the logic here is correct, I've noticed that this change, and indeed most of the `flashattn` function and the `ref_program`, is duplicated in `example_mha_fwd_bhsd_wgmma_pipelined.py`. This duplication increases the maintenance burden, as any future changes will need to be applied in both places.

To improve maintainability, consider refactoring the common logic into a shared module. These example files could then import the common components and only define the parts that are specific to them (like the `T.Pipelined` loop configuration and `main` function arguments). A sketch of that layout follows.
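A minimal sketch of the suggested refactor, assuming a hypothetical `mha_common.py` (the module name and the score computation are illustrative; only the masking line comes from this diff):

```python
# mha_common.py -- hypothetical shared module
import torch
import torch.nn.functional as F

def ref_program(Q, K, V, is_causal):
    # Reference attention shared by both examples.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)  # assumed scaling
    if is_causal:
        seq_q, seq_kv = Q.size(2), K.size(2)
        mask = torch.tril(
            torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float('-inf'))
    return torch.matmul(F.softmax(scores, dim=-1), V)

# Both example files would then `from mha_common import ref_program` (and likewise
# share the flashattn builder), keeping only the parts that differ -- the
# T.Pipelined configuration and main() arguments -- local.
```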