[Bugfix] Ensure correct handling for cases where seq_q<seq_kv in flash attention examples
#864
Conversation
Walkthrough

Introduces `past_len = seq_kv - seq_q` with validation, removes local redefinitions, and updates causal loop-range calculations and reference masks to account for sequence-length offsets in two Flash Attention MHA forward example kernels. A short sketch after the diagram below illustrates the offset rule.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Kernel as MHA Fwd Kernel
    participant MMA0 as MMA0 Block
    participant Ref as Reference Mask
    Caller->>Kernel: forward(seq_q, seq_kv, is_causal)
    Kernel->>Kernel: past_len = seq_kv - seq_q (assert past_len >= 0)
    alt is_causal
        Kernel->>Kernel: compute loop_range with past_len-adjusted bounds
        Kernel->>MMA0: launch tiles (uses outer past_len)
    else not causal
        Kernel->>Kernel: compute loop_range without offset
        Kernel->>MMA0: launch tiles
    end
    Note over Kernel,MMA0: Local past_len redefinition removed
    Caller->>Ref: build causal mask
    Ref->>Ref: torch.tril(..., offset = seq_kv - seq_q)
    Ref-->>Caller: masked attention output
```
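A minimal standalone sketch (illustrative, not code from the PR) of how the `diagonal` offset in `torch.tril` encodes the `j <= i + past_len` rule used by the kernel:

```python
import torch

# With past_len = seq_kv - seq_q, query row i may attend to key columns
# j <= i + past_len; torch.tril expresses this via its diagonal offset.
seq_q, seq_kv = 3, 5
past_len = seq_kv - seq_q  # 2

mask = torch.tril(torch.ones(seq_q, seq_kv, dtype=torch.bool), diagonal=past_len)
print(mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)

# Cross-check against the elementwise rule applied inside the kernel's mask.
i = torch.arange(seq_q).unsqueeze(1)
j = torch.arange(seq_kv).unsqueeze(0)
assert torch.equal(mask, j <= i + past_len)
```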
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks: ❌ 1 failed (warning), ✅ 2 passed.
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes

Hello @Rachmanino, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a bug in flash attention examples by enhancing the handling of causal masking and sequence lengths, particularly in scenarios where the query sequence length is less than the key-value sequence length (`seq_q < seq_kv`).
Code Review
This pull request correctly fixes a bug in the causal masking logic for flash attention examples when seq_kv > seq_q. The changes are applied consistently across both example_mha_fwd_bhsd.py and example_mha_fwd_bhsd_wgmma_pipelined.py by centralizing the past_len calculation, updating the loop range for key-value blocks, and adjusting the mask generation in the reference implementation. The changes are logical and improve correctness. My only feedback is a suggestion to address the significant code duplication between the two modified example files to improve long-term maintainability.
```python
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
```
While the logic here is correct, I've noticed that this change, and indeed most of the `flashattn` function and the `ref_program`, is duplicated in `example_mha_fwd_bhsd_wgmma_pipelined.py`. This duplication increases the maintenance burden, as any future changes will need to be applied in both places.

To improve maintainability, consider refactoring the common logic into a shared module. These example files could then import the common components and only define the parts that are specific to them (like the `T.Pipelined` loop configuration and main function arguments).
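As a rough illustration of that suggestion, a hypothetical shared module (the file name, function signature, and layout are invented for this sketch, not taken from the repository) might hold the reference implementation:

```python
# mha_fwd_common.py -- hypothetical shared module (names are illustrative)
import torch


def ref_program(Q, K, V, is_causal):
    """Reference MHA forward (BHSD layout) shared by both example scripts."""
    seq_q, seq_kv = Q.shape[2], K.shape[2]
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) / (Q.shape[-1] ** 0.5)
    if is_causal:
        # Offset causal mask: row i attends to columns j <= i + (seq_kv - seq_q).
        mask = torch.tril(
            torch.ones(seq_q, seq_kv, device=scores.device, dtype=torch.bool),
            diagonal=seq_kv - seq_q,
        ).unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(~mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", attn, V)
```

Each example script would then `from mha_fwd_common import ref_program` and keep only its kernel-specific pieces, such as the `T.Pipelined` loop configuration.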
Actionable comments posted: 2
🧹 Nitpick comments (4)
examples/flash_attention/example_mha_fwd_bhsd.py (2)
139-144: Same coverage, clearer expression for `loop_range`. Logic is right; consider the equivalent, slightly clearer form:
```diff
-loop_range = (
-    T.min(
-        T.ceildiv(seq_kv, block_N), T.ceildiv(
-            (bx + 1) * block_M +
-            past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+loop_range = (
+    T.ceildiv(T.min(seq_kv, (bx + 1) * block_M + past_len), block_N)
+    if is_causal else T.ceildiv(seq_kv, block_N))
```
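The two forms agree because `ceildiv` is monotonic, so the `min` can be taken before or after rounding up. A quick brute-force check (plain Python, illustration only):

```python
# Verify min(ceildiv(a, n), ceildiv(b, n)) == ceildiv(min(a, b), n):
# ceildiv is monotonic, so min commutes with it.
def ceildiv(a, n):
    return -(-a // n)

for a in range(1, 65):
    for b in range(1, 65):
        for n in (1, 2, 16, 64):
            assert min(ceildiv(a, n), ceildiv(b, n)) == ceildiv(min(a, b), n)
```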
166-169: Ref mask matches kernel semantics; minor boolean-mask tidy. Optional: build the mask as bool to avoid `== 0` and use the named `diagonal` arg for readability.

```diff
-mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
-mask = mask.unsqueeze(0).unsqueeze(0)
-scores = scores.masked_fill(mask == 0, float('-inf'))
+mask = torch.tril(
+    torch.ones(seq_q, seq_kv, device=scores.device, dtype=torch.bool),
+    diagonal=seq_kv - seq_q
+).unsqueeze(0).unsqueeze(0)
+scores = scores.masked_fill(~mask, float('-inf'))
```
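For what it's worth, a tiny check (illustrative only) that the boolean-mask form is behaviorally identical to the `== 0` form it replaces:

```python
import torch

seq_q, seq_kv = 4, 7
scores = torch.randn(1, 1, seq_q, seq_kv)

# Old style: float mask, compare against zero.
float_mask = torch.tril(torch.ones(seq_q, seq_kv), seq_kv - seq_q).unsqueeze(0).unsqueeze(0)
# New style: boolean mask, invert directly.
bool_mask = torch.tril(
    torch.ones(seq_q, seq_kv, dtype=torch.bool), diagonal=seq_kv - seq_q
).unsqueeze(0).unsqueeze(0)

old = scores.masked_fill(float_mask == 0, float("-inf"))
new = scores.masked_fill(~bool_mask, float("-inf"))
assert torch.equal(old, new)
```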
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (2)

139-144: Loop bound rewrite for readability (optional). Functionally equivalent but a bit clearer to read:
```diff
-loop_range = (
-    T.min(
-        T.ceildiv(seq_kv, block_N), T.ceildiv(
-            (bx + 1) * block_M +
-            past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+loop_range = (
+    T.ceildiv(T.min(seq_kv, (bx + 1) * block_M + past_len), block_N)
+    if is_causal else T.ceildiv(seq_kv, block_N))
```
171-174: Ref mask offset is right; consider boolean mask shortcut. Same optional tidy as above:
```diff
-mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
-mask = mask.unsqueeze(0).unsqueeze(0)
-scores = scores.masked_fill(mask == 0, float('-inf'))
+mask = torch.tril(
+    torch.ones(seq_q, seq_kv, device=scores.device, dtype=torch.bool),
+    diagonal=seq_kv - seq_q
+).unsqueeze(0).unsqueeze(0)
+scores = scores.masked_fill(~mask, float('-inf'))
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/flash_attention/example_mha_fwd_bhsd.py (3 hunks)
- examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
`ceildiv` (87-88)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
`ceildiv` (87-88)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (2)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
52-58: Causal mask math with past_len looks correct. `q_idx = bx*block_M + i + past_len` vs `k_idx = k*block_N + j` correctly enables `j <= i + past_len`. Please sanity‑check extreme edges (e.g., seq_q=1, seq_kv≫seq_q; and seq_q=seq_kv) on both small and misaligned block sizes.

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
52-58: Causal masking with offset is correct; edge‑case sweep advised. Approach mirrors the first file; please run a quick sweep over varied seq/block sizes to confirm no off‑by‑one at block boundaries.
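A sweep along those lines might look like the following (a sketch only; the `flashattn`/`ref_program` signatures are assumed from the example scripts and may differ, so adjust before running):

```python
# Hypothetical edge-case sweep; requires a CUDA device and assumes the
# example directory is on sys.path. Signatures are assumed, not verbatim.
import itertools
import torch
from example_mha_fwd_bhsd import flashattn, ref_program

batch, heads, dim = 1, 4, 64
for (seq_q, seq_kv), (block_M, block_N) in itertools.product(
        [(1, 1024), (64, 64), (63, 1024), (128, 129)],
        [(32, 32), (64, 32), (64, 64)]):
    kernel = flashattn(batch, heads, seq_q, seq_kv, dim,
                       is_causal=True, block_M=block_M, block_N=block_N)  # assumed signature
    Q = torch.randn(batch, heads, seq_q, dim, device="cuda", dtype=torch.float16)
    K = torch.randn(batch, heads, seq_kv, dim, device="cuda", dtype=torch.float16)
    V = torch.randn_like(K)
    # Compare the kernel against the PyTorch reference at loose fp16 tolerances.
    torch.testing.assert_close(kernel(Q, K, V), ref_program(Q, K, V, is_causal=True),
                               rtol=1e-2, atol=1e-2)
```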
```python
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
```
Restrict the seq_kv >= seq_q precondition to causal mode.
Same concern as the non‑pipelined variant: don’t prevent valid non‑causal seq_q > seq_kv cases.
```diff
-past_len = seq_kv - seq_q
-assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+past_len = seq_kv - seq_q
+if is_causal:
+    assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
past_len = seq_kv - seq_q
if is_causal:
    assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
```
```python
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
```
Gate the length invariant to causal mode (don’t block valid non‑causal cases).
Unconditionally asserting seq_kv >= seq_q forbids legitimate non‑causal runs where seq_q > seq_kv (e.g., cross‑attention). Restrict the assertion to is_causal.
```diff
-past_len = seq_kv - seq_q
-assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+past_len = seq_kv - seq_q
+if is_causal:
+    assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
past_len = seq_kv - seq_q
if is_causal:
    assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
```
🤖 Prompt for AI Agents
```
In examples/flash_attention/example_mha_fwd_bhsd.py around lines 37–39, the
unconditional calculation and assert of past_len (past_len = seq_kv - seq_q;
assert past_len >= 0) incorrectly forbids valid non-causal cases; guard this
logic with the is_causal check: only compute past_len and assert seq_kv >= seq_q
when is_causal is true, and skip or set a safe default for past_len in
non-causal paths so cross-attention (seq_q > seq_kv) is allowed.
```
This pull request updates both the `example_mha_fwd_bhsd.py` and `example_mha_fwd_bhsd_wgmma_pipelined.py` files to improve causal masking logic and ensure correct handling of sequence lengths in flash attention implementations. The main changes involve more robust calculation of the past sequence length, improved assertions, and more accurate mask generation for causal attention.

Causal masking and sequence length handling:
- Introduces `past_len` as `seq_kv - seq_q` at the start of `flashattn`, with an assertion to ensure `seq_kv >= seq_q`, in both files. This centralizes the logic and prevents errors. [1] [2]
- Updates the mask in `ref_program` to use `torch.tril(..., seq_kv - seq_q)`, which creates a more accurate causal mask when `seq_kv > seq_q`. [1] [2]

Block loop range calculation:
- Updates `loop_range` in the main function to account for `past_len` when determining the block range in the causal case, ensuring correct coverage of key-value blocks (see the sketch at the end of this description). [1] [2]

Code simplification:
- Removes the local redefinition of `past_len` from inside the `MMA0` macro, since it is now computed once at the top level. [1] [2]
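To make the loop-range change concrete, here is a small numeric sketch (illustrative values, not from the PR) showing why even the first query block must scan all `past_len` keys when `seq_kv > seq_q`:

```python
# With seq_q=64, seq_kv=192, block_M=block_N=64, past_len=128:
# query block bx=0 covers rows 0..63, which may attend keys up to
# 63 + 128 = 191, i.e. all 3 KV blocks -- not 1 as without the offset.
def ceildiv(a, n):
    return -(-a // n)

seq_q, seq_kv, block_M, block_N = 64, 192, 64, 64
past_len = seq_kv - seq_q  # 128

for bx in range(ceildiv(seq_q, block_M)):
    loop_range = min(ceildiv(seq_kv, block_N),
                     ceildiv((bx + 1) * block_M + past_len, block_N))
    print(bx, loop_range)  # prints: 0 3
```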