
Conversation

@Rachmanino (Collaborator) commented Oct 6, 2025

As titled, revert the MHA part of #940, since there is no reduction for dK and dV in the MHA backward pass.
This pull request simplifies the backward-pass implementation for FlashAttention in both example_mha_bwd.py and example_mha_bwd_wgmma_pipelined.py. The main change is the removal of the dual-kernel approach (atomic_add and split), consolidating both into a single backward kernel function. This streamlines the code and the user interface, making the API easier to use and maintain.

Refactoring and Kernel Consolidation

  • Removed the separate flashattn_bwd_atomic_add and flashattn_bwd_split kernel definitions, replacing them with a unified flashattn_bwd function in both files. This eliminates code duplication and simplifies kernel selection logic.
  • Updated the kernel launch logic to always use the consolidated flashattn_bwd function, removing the conditional logic and atomic/split flags from the backward pass (see the sketch below).
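
For orientation, here is a minimal sketch of the single-path backward described above, assuming a PyTorch autograd wrapper like the one in the examples; flashattn_bwd is a hypothetical stand-in for the compiled TileLang kernel, and the saved-tensor layout is illustrative rather than copied from the files.

```python
import torch


def flashattn_bwd(q, k, v, do, lse, dq, dk, dv):
    """Hypothetical stand-in for the single compiled TileLang backward kernel."""
    raise NotImplementedError("placeholder; the real kernel is compiled from TileLang")


class _attention(torch.autograd.Function):
    # forward (which launches flashattn_fwd and saves q, k, v, o, lse) is omitted here.

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        # In MHA every query head has its own K/V head, so no cross-head
        # reduction (atomic add or split-and-sum) is needed for dK/dV.
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        flashattn_bwd(q, k, v, do, lse, dq, dk, dv)
        return dq, dk, dv, None  # one gradient slot per forward input: q, k, v, causal
```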

API and Argument Simplification

  • Removed the use_atomic and use_split arguments from the main functions, CLI parsers, and the autograd forward methods, simplifying the user interface and invocation; a minimal CLI sketch follows.
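
A small sketch of the slimmed-down CLI, assuming the examples' argparse setup. The remaining --causal flag is written here with action='store_true' for correct flag semantics; the examples themselves keep type=bool, which the review comments further down flag.

```python
import argparse

parser = argparse.ArgumentParser()
# Dropped by this PR: the backward-variant selectors.
#   parser.add_argument('--use_atomic', action='store_true', help='Use atomic add for dK/dV')
#   parser.add_argument('--use_split', action='store_true', help='Use split for dK/dV')
parser.add_argument('--causal', action='store_true', help='Causal flag')

args = parser.parse_args([])  # parse an empty argv so the sketch runs standalone
print(args.causal)            # False unless --causal is passed
```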

Minor Implementation Updates

  • Hardcoded kernel parameters such as threads and num_stages in the kernel definition, removing them from the function signatures for clarity and consistency.
  • Added comments and minor code clarifications regarding local storage usage for large dimensions.

Output and Logging Clean-up

  • Removed unnecessary print statements and extra return values for a cleaner output and interface.

These changes make the backward pass codebase more maintainable and user-friendly, reducing complexity and potential confusion around kernel selection.

Summary by CodeRabbit

  • Refactor

    • Consolidated multiple internal backward implementations into a single, consistent backward path for flash attention.
    • Streamlined forward/backward flow and stabilized runtime threading for predictable performance.
    • Simplified public interfaces by removing atomic/split toggles.
  • Chores

    • Removed deprecated CLI flags and updated example usage and docs to match the simplified interface.

github-actions bot commented Oct 6, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

coderabbitai bot (Contributor) commented Oct 6, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

Consolidates backward-pass implementations in two flash attention examples into a single flashattn_bwd kernel; removes atomic/split variants and related CLI flags; hard-codes thread/stage parameters; and simplifies _attention.forward and main signatures and their call sites.

Changes

Cohort / File(s) — Change summary

  • Backward path consolidation (examples/flash_attention/example_mha_bwd.py, examples/flash_attention/example_mha_bwd_wgmma_pipelined.py): Removed flashattn_bwd_atomic_add and split variants; all backward calls now invoke a single flashattn_bwd(...) implementation; eliminated branching and auxiliary kernel exports.
  • API signature simplification (both example files): _attention.forward(ctx, q, k, v, causal, use_atomic=True) → _attention.forward(ctx, q, k, v, causal); main(..., causal: bool = False, use_atomic: bool = True) → main(..., causal: bool = False); updated call sites to drop use_atomic/use_split.
  • CLI and argument handling (both example files): Removed the --use_atomic and --use_split CLI flags and related parsing/printing; the CLI now accepts only --causal.
  • Kernel launch and staging adjustments (both example files): Hard-coded launch parameters: threads fixed to 128 at call sites; removed num_stages usage and hard-coded stage counts where applicable; updated comments about storage/fragments.
  • Public/private symbol cleanup (examples/flash_attention/example_mha_bwd.py): Deleted declarations/exports for flashattn_bwd_atomic_add and flashattn_bwd_split.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as CLI
  participant Main as main(...)
  participant Attn as _attention.forward
  participant KernF as flashattn_fwd
  participant KernB as flashattn_bwd
  CLI->>Main: parse args (causal)
  Main->>Attn: attention(Q,K,V, causal)
  Attn->>KernF: launch forward kernel
  KernF-->>Attn: O
  Note over Main,Attn: Backward pass (single path)
  Main->>Attn: backward(dO)
  Attn->>KernB: launch consolidated bwd
  KernB-->>Attn: dQ,dK,dV
  Attn-->>Main: grads
  Note right of KernB: previous atomic/split branches removed

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Poem

hop, hop—one kernel now, no forks to find,
gradients march forward, tidy in a line.
threads set steady at one-two-eight,
flags gone quiet — simplicity’s great.
i twitch my nose: consolidated, refined. 🐰✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped: CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title clearly describes the main change of reverting the atomic/split&sum templates in the MHA backward example code, making it obvious to reviewers what part of the code is affected without extraneous detail.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71218d6 and dd0dbe3.

📒 Files selected for processing (2)
  • examples/flash_attention/example_mha_bwd.py (8 hunks)
  • examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (8 hunks)


@Rachmanino requested a review from LeiWang1999 on October 6, 2025 04:43
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3aecab8 and 71218d6.

📒 Files selected for processing (2)
  • examples/flash_attention/example_mha_bwd.py (7 hunks)
  • examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/flash_attention/example_mha_bwd.py (2)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (2)
  • flashattn_bwd (149-248)
  • forward (254-262)
tilelang/language/kernel.py (2)
  • Kernel (209-258)
  • threads (195-199)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (3)
examples/flash_attention/example_mha_bwd.py (3)
  • flashattn_bwd (152-239)
  • forward (245-252)
  • main (297-344)
tilelang/language/kernel.py (2)
  • Kernel (209-258)
  • threads (195-199)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
🪛 Ruff (0.13.3)
examples/flash_attention/example_mha_bwd.py

314-314: Ambiguous variable name: O

(E741)

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py

324-324: Ambiguous variable name: O

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: bot-task
🔇 Additional comments (6)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (3)

171-174: Helpful documentation about memory optimization.

The comments clarify when local fragment allocation should be avoided for large dimensions, which is valuable context for future maintainers.


254-262: LGTM! Simplified to single-path implementation.

The removal of use_atomic parameter correctly reflects the consolidation to a single backward kernel path.


280-287: LGTM! Clean single-path backward implementation.

The backward method correctly invokes the consolidated flashattn_bwd kernel with appropriate tensor allocations.

examples/flash_attention/example_mha_bwd.py (3)

174-177: Helpful documentation about memory optimization.

The comments clarify when local fragment allocation should be avoided for large dimensions, providing valuable guidance for future optimization work.


245-252: LGTM! Simplified to single-path implementation.

The removal of use_atomic parameter correctly reflects the consolidation to a single backward kernel path, consistent with the PR objectives.


270-277: LGTM! Clean single-path backward implementation.

The backward method correctly invokes the consolidated flashattn_bwd kernel with appropriate tensor allocations.

parser.add_argument(
    '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
    '--use_split', action='store_true', default=False, help='Use split for dK/dV')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')

⚠️ Potential issue | 🔴 Critical

Fix argparse boolean handling.

Using type=bool with argparse doesn't work as expected. Any non-empty string (including "False") will be converted to True. Command-line usage like --causal False will incorrectly set causal=True.

Apply this diff to fix the boolean argument parsing:

-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', action='store_true', help='Causal flag')

Alternatively, if you need explicit True/False control:

-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', type=lambda x: x.lower() == 'true', default=False, help='Causal flag')
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
parser.add_argument('--causal', action='store_true', help='Causal flag')
🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_bwd_wgmma_pipelined.py around line 363,
the argparse boolean is defined with type=bool which treats any non-empty string
as True; replace it with a proper boolean parser by either using
action='store_true' (for a flag that defaults to False and becomes True when
present) or implement a small str2bool converter and use type=str2bool with
choices or default to correctly parse explicit "True"/"False" strings; update
the parser.add_argument call accordingly so command-line values like "--causal
False" behave as expected.

parser.add_argument(
    '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
    '--use_split', action='store_true', default=False, help='Use split for dK/dV')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')

⚠️ Potential issue | 🔴 Critical

Fix argparse boolean handling.

Using type=bool with argparse doesn't work as expected. Any non-empty string (including "False") will be converted to True. Command-line usage like --causal False will incorrectly set causal=True.

Apply this diff to fix the boolean argument parsing:

-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', action='store_true', help='Causal flag')

Alternatively, if you need explicit True/False control:

-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', type=lambda x: x.lower() == 'true', default=False, help='Causal flag')
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
parser.add_argument('--causal', action='store_true', help='Causal flag')
🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_bwd.py around line 353, the argparse
call uses type=bool which treats any non-empty string (e.g. "False") as True;
replace it with a proper boolean flag by using parser.add_argument('--causal',
action='store_true', help='Causal flag') so passing --causal sets True and
omitting it leaves False; if you need explicit True/False parsing from strings
instead, add a small str2bool helper that maps common truthy/falsey strings to
booleans and use type=str2bool (and keep default=False).

@LeiWang1999 merged commit 481cae4 into tile-ai:main on Oct 6, 2025 (3 checks passed).
coderabbitai bot mentioned this pull request on Oct 10, 2025.
