[BugFix] Add memory order and testing script for split version GQA bwd kernel #1100
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run […]. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
The PR updates the GQA varlen FlashAttention example to use release memory ordering on its atomic_add accumulation, adds CLI flags to choose between the atomic and split-based dQ update paths, and re-enables the dQ assertion in the CUDA test.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant CLI as CLI Flags
    participant Example as example_gqa_bwd_tma_reduce_varlen.main
    participant Backward as GQA Backward Path
    participant Atomic as atomic_add (dQ)
    participant Split as split-based update
    participant Test as test runner
    Note over CLI,Example: start example with flags (or none)
    CLI->>Example: invoke main()
    Example->>Backward: run backward accumulation
    alt --use_atomic
        Backward->>Atomic: choose atomic path
        Atomic->>Atomic: atomic_add(..., memory_order="release")
    else --use_split
        Backward->>Split: choose split-based update
    else (no flag)
        Backward->>Split: default to split-based update
    end
    Test->>Example: test invokes main() (CUDA)
    Note over Atomic,Split: dK/dV checks + dQ assertion re-enabled
```
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (2)
526-526: Elementwise atomic for dQ with release — OK; watch contention.
Fine for correctness. For throughput, consider block/warp-aggregated accumulation followed by a single atomic per row/col tile in a follow-up (a rough sketch of the idea follows).
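As a rough illustration of that follow-up idea, here is a minimal, hypothetical Python sketch (plain torch on the host, not TileLang kernel code; all names are placeholders) of accumulating a tile's contribution locally and publishing it with a single update per tile instead of one atomic per element:

```python
# Hypothetical sketch of "aggregate locally, publish once per tile".
# On the GPU, the single publish below would correspond to one atomic update
# per row/col tile, replacing the current elementwise atomic_add pattern.
import torch


def accumulate_dq_tile(dq_global: torch.Tensor, partials: list[torch.Tensor],
                       row0: int, col0: int) -> None:
    # Block/warp-aggregated step: sum all partial dQ contributions for this tile locally.
    tile = torch.stack(partials).sum(dim=0)
    rows, cols = tile.shape
    # Single publish per tile (stands in for one atomic add of the whole tile).
    dq_global[row0:row0 + rows, col0:col0 + cols] += tile
```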
785-787: Default behavior mismatch between CLI and programmatic calls.
CLI defaults to split (use_atomic=False), but main() still defaults to True. Either flip main’s default to False or always pass use_atomic explicitly at call sites to avoid confusion (see the sketch after the suggested change).
- def main(..., use_atomic: bool = True):
+ def main(..., use_atomic: bool = False):
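One low-risk way to keep the two defaults aligned is to derive both from a single constant. The sketch below is illustrative only (the flag names --use_atomic/--use_split come from this PR; the parser and constant are assumptions, not the example's actual code):

```python
# Sketch: keep main()'s programmatic default and the CLI default in sync
# by deriving both from one constant, so they cannot drift apart.
import argparse

DEFAULT_USE_ATOMIC = False  # split-based update is the default path


def main(use_atomic: bool = DEFAULT_USE_ATOMIC) -> None:
    path = "atomic" if use_atomic else "split"
    print(f"running the {path} dQ update path")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--use_atomic", dest="use_atomic", action="store_true")
    group.add_argument("--use_split", dest="use_atomic", action="store_false")
    parser.set_defaults(use_atomic=DEFAULT_USE_ATOMIC)
    args = parser.parse_args()
    main(use_atomic=args.use_atomic)
```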
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (3 hunks)
examples/flash_attention/test_example_flash_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/test_example_flash_attention.py (1)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)
main(691-758)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (2)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (2)
367-372: Atomic adds now use release ordering — good. Please confirm necessity.
Change looks correct and harmless. If no in-kernel acquire reads of dQ/dK/dV occur, relaxed would also be sufficient and slightly cheaper. If you rely on happens-before (release) visibility for downstream kernels, consider noting it in a comment.
Also applies to: 373-383
742-742: Re-enable dQ check — good.
Looks fine; if the atomic path is used, watch for rare non-deterministic tolerance trips. The current rtol/atol should be OK with fp32 accumulation (an illustrative sketch of this kind of check follows).
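For context, a hypothetical sketch of the kind of dQ comparison this refers to; the tolerances are illustrative placeholders, not the example's actual values:

```python
# Illustrative only: compares an atomically accumulated dQ against a reference,
# with tolerances loose enough to absorb fp32 accumulation-order differences.
import torch


def check_dq(dq: torch.Tensor, dq_ref: torch.Tensor,
             rtol: float = 1e-2, atol: float = 1e-2) -> None:
    torch.testing.assert_close(dq, dq_ref, rtol=rtol, atol=atol)
```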
@tilelang.testing.requires_cuda
def test_example_gqa_bwd_tma_reduce_varlen():
    example_gqa_bwd_tma_reduce_varlen.main()
Test does not exercise “split” path and lacks compute-capability guard.
As written, main() runs the atomic path (default True) and may fail on GPUs older than SM90. Call the split path explicitly and add the same compute-capability guard used elsewhere.
-@tilelang.testing.requires_cuda
-def test_example_gqa_bwd_tma_reduce_varlen():
- example_gqa_bwd_tma_reduce_varlen.main()
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_gqa_bwd_tma_reduce_varlen():
+ example_gqa_bwd_tma_reduce_varlen.main(use_atomic=False)

🤖 Prompt for AI Agents
In examples/flash_attention/test_example_flash_attention.py around lines 18-21, the test calls example_gqa_bwd_tma_reduce_varlen.main(), which by default runs the atomic path and may fail on GPUs older than SM90. Modify the test to explicitly invoke the split path by calling main(use_atomic=False), and add the same compute-capability guard used elsewhere (e.g., @tilelang.testing.requires_cuda_compute_version_ge(9, 0), as in the suggested change) above the test so it only runs on supported GPUs.
Summary by CodeRabbit
Tests
Examples