[Example] Revert the atomic/split&sum templates in MHA backward examples #943
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Caution: Review failed. The pull request is closed.

Walkthrough

Consolidates the backward-pass implementations in two flash attention examples into a single flashattn_bwd kernel; removes the atomic/split variants and related CLI flags; hard-codes thread/stage parameters; and simplifies _attention.forward and main signatures and their call sites.
Sequence Diagram(s)

sequenceDiagram
autonumber
participant CLI as CLI
participant Main as main(...)
participant Attn as _attention.forward
participant KernF as flashattn_fwd
participant KernB as flashattn_bwd
CLI->>Main: parse args (causal)
Main->>Attn: attention(Q,K,V, causal)
Attn->>KernF: launch forward kernel
KernF-->>Attn: O
Note over Main,Attn: Backward pass (single path)
Main->>Attn: backward(dO)
Attn->>KernB: launch consolidated bwd
KernB-->>Attn: dQ,dK,dV
Attn-->>Main: grads
Note right of KernB: previous atomic/split branches removed
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/flash_attention/example_mha_bwd.py (7 hunks)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/flash_attention/example_mha_bwd.py (2)
  examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (2)
    flashattn_bwd (149-248)
    forward (254-262)
  tilelang/language/kernel.py (2)
    Kernel (209-258)
    threads (195-199)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (3)
  examples/flash_attention/example_mha_bwd.py (3)
    flashattn_bwd (152-239)
    forward (245-252)
    main (297-344)
  tilelang/language/kernel.py (2)
    Kernel (209-258)
    threads (195-199)
  tilelang/language/pipeline.py (1)
    Pipelined (9-46)
🪛 Ruff (0.13.3)
examples/flash_attention/example_mha_bwd.py
314-314: Ambiguous variable name: O
(E741)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
324-324: Ambiguous variable name: O
(E741)
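For context, a minimal illustration of what Ruff's E741 rule flags and two common ways to address it; the variable values below are placeholders, not code from the example files:

```python
# Ruff reports E741 for single-letter names that are easy to confuse: l, O, I.
O = [1, 2, 3]        # triggers "E741 Ambiguous variable name: O"

# Option 1: rename to something descriptive.
out = [1, 2, 3]

# Option 2: keep the conventional name and silence the check on that line only.
O = [1, 2, 3]  # noqa: E741
```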
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: bot-task
🔇 Additional comments (6)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (3)
171-174: Helpful documentation about memory optimization. The comments clarify when local fragment allocation should be avoided for large dimensions, which is valuable context for future maintainers.
254-262: LGTM! Simplified to single-path implementation. The removal of the use_atomic parameter correctly reflects the consolidation to a single backward kernel path.
280-287: LGTM! Clean single-path backward implementation. The backward method correctly invokes the consolidated flashattn_bwd kernel with appropriate tensor allocations.
examples/flash_attention/example_mha_bwd.py (3)
174-177: Helpful documentation about memory optimization. The comments clarify when local fragment allocation should be avoided for large dimensions, providing valuable guidance for future optimization work.
245-252: LGTM! Simplified to single-path implementation. The removal of the use_atomic parameter correctly reflects the consolidation to a single backward kernel path, consistent with the PR objectives.
270-277: LGTM! Clean single-path backward implementation. The backward method correctly invokes the consolidated flashattn_bwd kernel with appropriate tensor allocations.
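For orientation, here is a minimal, hypothetical sketch of the single-path pattern praised above: a torch.autograd.Function whose backward launches one consolidated kernel and allocates dQ/dK/dV itself. The fwd_kernel/bwd_kernel callables, the saved tensors, and the kernel argument order are assumptions for illustration, not the actual code in these example files.

```python
import torch


class _attention(torch.autograd.Function):
    """Hypothetical single-path wrapper: one forward kernel, one backward kernel."""

    @staticmethod
    def forward(ctx, q, k, v, causal, fwd_kernel, bwd_kernel):
        # Assumed kernel contract: returns the attention output plus the
        # log-sum-exp tensor needed to recompute the softmax in backward.
        o, lse = fwd_kernel(q, k, v, causal)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        ctx.bwd_kernel = bwd_kernel
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        # Single consolidated backward path: the kernel writes dQ/dK/dV
        # directly, so no atomic-add or split-then-sum variant is needed.
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        ctx.bwd_kernel(q, k, v, o, lse, do, dq, dk, dv, ctx.causal)
        # One gradient per forward input; non-tensor inputs get None.
        return dq, dk, dv, None, None, None
```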
    '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
    '--use_split', action='store_true', default=False, help='Use split for dK/dV')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
Fix argparse boolean handling.
Using type=bool with argparse doesn't work as expected. Any non-empty string (including "False") will be converted to True. Command-line usage like --causal False will incorrectly set causal=True.
Apply this diff to fix the boolean argument parsing:
- parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+ parser.add_argument('--causal', action='store_true', help='Causal flag')

Alternatively, if you need explicit True/False control:
- parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+ parser.add_argument('--causal', type=lambda x: x.lower() == 'true', default=False, help='Causal flag')

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+ parser.add_argument('--causal', action='store_true', help='Causal flag')
🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_bwd_wgmma_pipelined.py around line 363,
the argparse boolean is defined with type=bool which treats any non-empty string
as True; replace it with a proper boolean parser by either using
action='store_true' (for a flag that defaults to False and becomes True when
present) or implement a small str2bool converter and use type=str2bool with
choices or default to correctly parse explicit "True"/"False" strings; update
the parser.add_argument call accordingly so command-line values like "--causal
False" behave as expected.
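As a concrete illustration of the str2bool approach mentioned in the prompt above, here is a minimal, self-contained sketch; the helper name and the accepted spellings are assumptions, not code from this repository:

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse common truthy/falsey spellings into a real bool."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()
# With this converter, `--causal False` really does set causal=False.
parser.add_argument('--causal', type=str2bool, default=False, help='Causal flag')
args = parser.parse_args(['--causal', 'False'])
assert args.causal is False
```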
    '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
    '--use_split', action='store_true', default=False, help='Use split for dK/dV')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
Fix argparse boolean handling.
Using type=bool with argparse doesn't work as expected. Any non-empty string (including "False") will be converted to True. Command-line usage like --causal False will incorrectly set causal=True.
Apply this diff to fix the boolean argument parsing:
- parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+ parser.add_argument('--causal', action='store_true', help='Causal flag')

Alternatively, if you need explicit True/False control:
- parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+ parser.add_argument('--causal', type=lambda x: x.lower() == 'true', default=False, help='Causal flag')

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+ parser.add_argument('--causal', action='store_true', help='Causal flag')
🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_bwd.py around line 353, the argparse
call uses type=bool which treats any non-empty string (e.g. "False") as True;
replace it with a proper boolean flag by using parser.add_argument('--causal',
action='store_true', help='Causal flag') so passing --causal sets True and
omitting it leaves False; if you need explicit True/False parsing from strings
instead, add a small str2bool helper that maps common truthy/falsey strings to
booleans and use type=str2bool (and keep default=False).
As titled, revert the MHA part of #940, since there is no reduction for dk and dv in the MHA backward process.

This pull request simplifies the backward pass implementation for FlashAttention in both example_mha_bwd.py and example_mha_bwd_wgmma_pipelined.py. The main improvement is the removal of the dual kernel approach (atomic_add and split), consolidating them into a single backward kernel function. This streamlines both the code and the user interface, making the API easier to use and maintain.

Refactoring and Kernel Consolidation

- Removed the flashattn_bwd_atomic_add and flashattn_bwd_split kernel definitions, replacing them with a unified flashattn_bwd function in both files. This eliminates code duplication and simplifies kernel selection logic. [1] [2]
- Switched the backward pass to always call the single flashattn_bwd function, removing conditional logic and atomic/split flags. [1] [2]

API and Argument Simplification

- Removed the use_atomic and use_split arguments from the main functions, CLI parsers, and the autograd forward methods, simplifying the user interface and invocation. [1] [2] [3] [4] [5] [6] [7]

Minor Implementation Updates

- Hard-coded threads and num_stages in the kernel definition, removing them from the function signatures for clarity and consistency. [1] [2] [3] [4]

Output and Logging Clean-up

These changes make the backward pass codebase more maintainable and user-friendly, reducing complexity and potential confusion around kernel selection.
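To make the motivating point concrete (plain MHA needs no cross-head reduction for dK/dV, whereas grouped-query attention does), here is a small PyTorch sketch using a naive reference attention. The naive_attn helper, the shapes, and the groups value are illustrative assumptions, not code from these examples.

```python
import torch

B, H, S, D = 2, 4, 16, 8      # batch, query heads, sequence length, head dim
groups = 2                    # query heads per shared K/V head in the GQA case


def naive_attn(q, k, v):
    # Plain scaled-dot-product attention as a reference (no masking).
    scores = (q @ k.transpose(-1, -2)) / D**0.5
    return torch.softmax(scores, dim=-1) @ v


# MHA: every query head has its own K/V head, so dK/dV are independent
# per-head gradients and no reduction (atomic add or split+sum) is required.
q = torch.randn(B, H, S, D, requires_grad=True)
k = torch.randn(B, H, S, D, requires_grad=True)
v = torch.randn(B, H, S, D, requires_grad=True)
naive_attn(q, k, v).sum().backward()
assert k.grad.shape == k.shape        # one gradient block per K/V head

# GQA/MQA: several query heads share one K/V head, so the shared gradient is
# the sum of per-query-head contributions, i.e. exactly the kind of reduction
# that plain MHA does not need.
kv = torch.randn(B, H // groups, S, D, requires_grad=True)
k_shared = kv.repeat_interleave(groups, dim=1)     # one copy per query head
naive_attn(q.detach(), k_shared, v.detach()).sum().backward()
assert kv.grad.shape == kv.shape      # contributions from `groups` heads summed
```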
Summary by CodeRabbit
Refactor
Chores