
Conversation

Collaborator

@Rachmanino commented Sep 23, 2025

This pull request updates both the example_mha_fwd_bhsd.py and example_mha_fwd_bhsd_wgmma_pipelined.py files to improve causal masking logic and ensure correct handling of sequence lengths in flash attention implementations. The main changes involve more robust calculation of the past sequence length, improved assertions, and more accurate mask generation for causal attention.

Causal masking and sequence length handling:

  • Added calculation of past_len as seq_kv - seq_q at the start of flashattn, with an assertion to ensure seq_kv >= seq_q for both files. This centralizes logic and prevents errors. [1] [2]
  • Updated the mask generation in ref_program to use torch.tril(..., seq_kv - seq_q), which produces the correct causal mask when seq_kv > seq_q (see the sketch after this list). [1] [2]
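
A minimal standalone sketch of that offset mask (illustrative shapes and lengths, not the example's exact code):

import torch

seq_q, seq_kv = 8, 20                       # illustrative lengths with seq_kv >= seq_q
past_len = seq_kv - seq_q
scores = torch.randn(1, 1, seq_q, seq_kv)   # [batch, heads, seq_q, seq_kv]

# Shifting the diagonal by past_len lets query row i attend to keys 0..i + past_len.
mask = torch.tril(torch.ones(seq_q, seq_kv, dtype=torch.bool), diagonal=past_len)
scores = scores.masked_fill(~mask.unsqueeze(0).unsqueeze(0), float('-inf'))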

Block loop range calculation:

  • Modified the calculation of loop_range in the main function to account for past_len when determining the block range in the causal case, ensuring correct coverage of key-value blocks (sketched below). [1] [2]
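
In scalar terms, the adjusted bound works out like this (a plain-Python sketch with a local ceildiv; the kernel itself uses T.min and T.ceildiv):

def ceildiv(a, b):
    return (a + b - 1) // b

block_M, block_N = 64, 32
seq_q, seq_kv = 128, 192
past_len = seq_kv - seq_q            # 64

for bx in range(ceildiv(seq_q, block_M)):
    # Causal: visit only the KV blocks this query tile can attend to.
    loop_range = min(ceildiv(seq_kv, block_N),
                     ceildiv((bx + 1) * block_M + past_len, block_N))
    # bx=0 -> 4 blocks, bx=1 -> 6 blocks for these sizes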

Code simplification:

  • Removed redundant calculation of past_len from inside the MMA0 macro, since it is now computed once at the top level. [1] [2]

Summary by CodeRabbit

  • New Features

    • Added support for causal attention when key/value length exceeds query length, enabling past-context handling in example MHA forward paths.
  • Bug Fixes

    • Corrected causal masking and loop bounds to handle differing query and key/value sequence lengths reliably.
    • Added validation to prevent negative past-context lengths.
  • Tests

    • Updated example reference behavior to align masking with past-context offsets, ensuring consistency across implementations.

Contributor

coderabbitai bot commented Sep 23, 2025

Walkthrough

Introduces past_len = seq_kv - seq_q with validation, removes local redefinitions, and updates causal loop range calculations and reference masks to account for sequence length offsets in two Flash Attention MHA forward example kernels.

Changes

  • Flash Attention MHA forward (BHSD): examples/flash_attention/example_mha_fwd_bhsd.py
    Add global past_len = seq_kv - seq_q with an assert; remove local past_len in MMA0; adjust causal loop_range to include past_len; update the reference causal mask to use an offset in torch.tril.
  • Flash Attention MHA forward pipelined (WGMMA): examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
    Add and validate past_len; remove local past_len in MMA0; modify causal loop_range in the forward path and main kernel loop to include past_len; update the reference mask with a torch.tril offset.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant Kernel as MHA Fwd Kernel
  participant MMA0 as MMA0 Block
  participant Ref as Reference Mask

  Caller->>Kernel: forward(seq_q, seq_kv, is_causal)
  Kernel->>Kernel: past_len = seq_kv - seq_q (assert past_len >= 0)
  alt is_causal
    Kernel->>Kernel: compute loop_range with past_len-adjusted bounds
    Kernel->>MMA0: launch tiles (uses outer past_len)
  else not causal
    Kernel->>Kernel: compute loop_range without offset
    Kernel->>MMA0: launch tiles
  end
  Note over Kernel,MMA0: Local past_len redefinition removed

  Caller->>Ref: build causal mask
  Ref->>Ref: torch.tril(..., offset = seq_kv - seq_q)
  Ref-->>Caller: masked attention output

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

I twitch my ears at sliding time,
past_len hops in measured rhyme.
Loops now bound where keys outpace,
tril drops masks with offset grace.
Two kernels nibble bugs away—
behold, causality in play! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly and succinctly identifies the primary change, a bugfix to handle cases where seq_q < seq_kv in the flash-attention example files; it matches the PR objectives and file-level summaries, is specific to the examples, and directly reflects the main code adjustments.

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist
Contributor

Summary of Changes

Hello @Rachmanino, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a bug in flash attention examples by enhancing the handling of causal masking and sequence lengths, particularly in scenarios where the query sequence length is less than the key-value sequence length (seq_q < seq_kv). The changes ensure the flash attention implementations correctly apply causal masks and iterate over key-value blocks, improving the robustness and accuracy of the attention mechanism.

Highlights

  • Causal Masking Logic: Introduced a centralized past_len calculation (seq_kv - seq_q) with an assertion (past_len >= 0) at the start of the flashattn function to ensure seq_kv is always greater than or equal to seq_q.
  • Accurate Mask Generation: Updated the ref_program to use torch.tril(..., seq_kv - seq_q) for mask generation, providing a more precise causal mask when seq_kv exceeds seq_q.
  • Block Loop Range Adjustment: Modified the loop_range calculation within the main function to correctly incorporate past_len for causal cases, ensuring proper coverage of key-value blocks.
  • Code Refinement: Eliminated redundant past_len calculations from within the MMA0 macro, as this value is now computed once at a higher level.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.


Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a bug in the causal masking logic for flash attention examples when seq_kv > seq_q. The changes are applied consistently across both example_mha_fwd_bhsd.py and example_mha_fwd_bhsd_wgmma_pipelined.py by centralizing the past_len calculation, updating the loop range for key-value blocks, and adjusting the mask generation in the reference implementation. The changes are logical and improve correctness. My only feedback is a suggestion to address the significant code duplication between the two modified example files to improve long-term maintainability.

Comment on lines +37 to +38
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
Severity: medium

While the logic here is correct, I've noticed that this change, and indeed most of the flashattn function and the ref_program, is duplicated in example_mha_fwd_bhsd_wgmma_pipelined.py. This duplication increases the maintenance burden, as any future changes will need to be applied in both places.

To improve maintainability, consider refactoring the common logic into a shared module. These example files could then import the common components and only define the parts that are specific to them (like the T.Pipelined loop configuration and main function arguments).
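
One possible shape for such a shared module (the file and helper names below are hypothetical, not part of this PR):

# examples/flash_attention/_mha_common.py (hypothetical shared helpers)
import torch

def validated_past_len(seq_q: int, seq_kv: int, is_causal: bool) -> int:
    """Offset of the first query row on the KV timeline, validated for causal use."""
    past_len = seq_kv - seq_q
    if is_causal:
        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
    return past_len

def causal_mask(seq_q: int, seq_kv: int, device=None) -> torch.Tensor:
    """Boolean mask letting query row i attend to keys 0..i + (seq_kv - seq_q)."""
    ones = torch.ones(seq_q, seq_kv, dtype=torch.bool, device=device)
    return torch.tril(ones, diagonal=seq_kv - seq_q)

Each example would then import these helpers and keep only its kernel-specific pieces, such as the T.Pipelined configuration and its main arguments.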

Contributor

@coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (4)
examples/flash_attention/example_mha_fwd_bhsd.py (2)

139-144: Same coverage, clearer expression for loop_range.

Logic is right; consider the equivalent, slightly clearer form.

-            loop_range = (
-                T.min(
-                    T.ceildiv(seq_kv, block_N), T.ceildiv(
-                        (bx + 1) * block_M +
-                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+            loop_range = (
+                T.ceildiv(T.min(seq_kv, (bx + 1) * block_M + past_len), block_N)
+                if is_causal else T.ceildiv(seq_kv, block_N))

166-169: Ref mask matches kernel semantics; minor boolean-mask tidy.

Optional: build the mask as bool to avoid == 0 and use the named diagonal arg for readability.

-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
-        mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        mask = torch.tril(
+            torch.ones(seq_q, seq_kv, device=scores.device, dtype=torch.bool),
+            diagonal=seq_kv - seq_q
+        ).unsqueeze(0).unsqueeze(0)
+        scores = scores.masked_fill(~mask, float('-inf'))
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (2)

139-144: Loop bound rewrite for readability (optional).

Functionally equivalent but a bit clearer to read.

-            loop_range = (
-                T.min(
-                    T.ceildiv(seq_kv, block_N), T.ceildiv(
-                        (bx + 1) * block_M +
-                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+            loop_range = (
+                T.ceildiv(T.min(seq_kv, (bx + 1) * block_M + past_len), block_N)
+                if is_causal else T.ceildiv(seq_kv, block_N))

171-174: Ref mask offset is right; consider boolean mask shortcut.

Same optional tidy as above.

-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
-        mask = mask.unsqueeze(0).unsqueeze(0)
-        scores = scores.masked_fill(mask == 0, float('-inf'))
+        mask = torch.tril(
+            torch.ones(seq_q, seq_kv, device=scores.device, dtype=torch.bool),
+            diagonal=seq_kv - seq_q
+        ).unsqueeze(0).unsqueeze(0)
+        scores = scores.masked_fill(~mask, float('-inf'))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3b21a67 and 6136ef0.

📒 Files selected for processing (2)
  • examples/flash_attention/example_mha_fwd_bhsd.py (3 hunks)
  • examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
  • ceildiv (87-88)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (1)
  • ceildiv (87-88)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task
🔇 Additional comments (2)
examples/flash_attention/example_mha_fwd_bhsd.py (1)

52-58: Causal mask math with past_len looks correct.

q_idx = bx*block_M + i + past_len versus k_idx = k*block_N + j correctly admits only keys with k_idx <= q_idx, i.e., each query row attends to keys up to its own position shifted by past_len. Please sanity-check extreme edges (e.g., seq_q=1 with seq_kv ≫ seq_q, and seq_q = seq_kv) on both small and misaligned block sizes.
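
A standalone way to exercise those edges (a sketch reusing the review's index names, independent of the TileLang kernel):

import torch

def predicate_matches_tril(seq_q, seq_kv):
    past_len = seq_kv - seq_q
    q_idx = torch.arange(seq_q).unsqueeze(1) + past_len   # query positions on the KV timeline
    k_idx = torch.arange(seq_kv).unsqueeze(0)             # key positions
    kernel_mask = k_idx <= q_idx                          # predicate mirrored from MMA0
    ref_mask = torch.tril(torch.ones(seq_q, seq_kv, dtype=torch.bool), diagonal=past_len)
    return torch.equal(kernel_mask, ref_mask)

# seq_q=1 with seq_kv >> seq_q, seq_q == seq_kv, and a misaligned case
assert all(predicate_matches_tril(q, kv) for q, kv in [(1, 4096), (128, 128), (96, 160)])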

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)

52-58: Causal masking with offset is correct; edge‑case sweep advised.

Approach mirrors the first file; please run a quick sweep over varied seq/block sizes to confirm no off‑by‑one at block boundaries.
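
A quick brute-force sweep along those lines (assumed names; plain-Python ceildiv in place of T.ceildiv):

def ceildiv(a, b):
    return -(-a // b)

# For each query tile bx, the causal loop_range must cover the KV block that holds
# the furthest key any row of that tile may attend to.
for seq_q, seq_kv, block_M, block_N in [(1, 4096, 64, 32), (96, 160, 64, 64), (128, 128, 64, 32)]:
    past_len = seq_kv - seq_q
    for bx in range(ceildiv(seq_q, block_M)):
        loop_range = min(ceildiv(seq_kv, block_N),
                         ceildiv((bx + 1) * block_M + past_len, block_N))
        last_q = min((bx + 1) * block_M, seq_q) - 1   # last query row in this tile
        last_k = last_q + past_len                    # furthest visible key index
        assert last_k // block_N < loop_range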

Comment on lines +37 to +39
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"


⚠️ Potential issue

Restrict the seq_kv >= seq_q precondition to causal mode.

Same concern as the non‑pipelined variant: don’t prevent valid non‑causal seq_q > seq_kv cases.

-    past_len = seq_kv - seq_q
-    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+    past_len = seq_kv - seq_q
+    if is_causal:
+        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    past_len = seq_kv - seq_q
-    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+    past_len = seq_kv - seq_q
+    if is_causal:
+        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"

Comment on lines +37 to +39
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"


⚠️ Potential issue

Gate the length invariant to causal mode (don’t block valid non‑causal cases).

Unconditionally asserting seq_kv >= seq_q forbids legitimate non‑causal runs where seq_q > seq_kv (e.g., cross‑attention). Restrict the assertion to is_causal.

-    past_len = seq_kv - seq_q
-    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+    past_len = seq_kv - seq_q
+    if is_causal:
+        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    past_len = seq_kv - seq_q
-    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+    past_len = seq_kv - seq_q
+    if is_causal:
+        assert past_len >= 0, "In causal mode, require seq_kv >= seq_q"
🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_fwd_bhsd.py around lines 37–39, the
unconditional calculation and assert of past_len (past_len = seq_kv - seq_q;
assert past_len >= 0) incorrectly forbids valid non‑causal cases; guard this
logic with the is_causal check: only compute past_len and assert seq_kv >= seq_q
when is_causal is true, and skip or set a safe default for past_len in
non‑causal paths so cross‑attention (seq_q > seq_kv) is allowed.

@LeiWang1999 LeiWang1999 merged commit b12a63c into tile-ai:main Sep 23, 2025
6 of 7 checks passed
@Rachmanino Rachmanino deleted the fix-fa branch September 23, 2025 05:01