
Conversation

@tzj-fxz (Contributor) commented Oct 22, 2025

Summary by CodeRabbit

  • Tests

    • Added a CUDA test for the flash attention example and re-enabled gradient validation for the backward pass.
  • Examples

    • Default backward computation method now uses a split-based update for typical runs.
    • Improved backward accumulation semantics for correctness.
    • Example no longer sets a fixed random seed at import time.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project so your changes are properly linted and formatted; this will help your contribution pass the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Oct 22, 2025

Walkthrough

The PR updates the GQA varlen FlashAttention example to use memory_order="release" for an atomic_add in the backward path, switches the default dQ update strategy to the split-based path when no flag is given, re-enables the dQ gradient assertion, removes an import-time seed, and adds a test that runs the example.

Changes

Cohort / File(s) — Summary

GQA Backward Gradient Accumulation
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
  • atomic_add in the backward pass now uses memory_order="release" for dQ accumulation; the default backward update switches to the split-based path when neither --use_atomic nor --use_split is provided (a minimal flag-handling sketch follows below); the dQ gradient assertion (assert_close) is re-enabled; the initial torch.manual_seed(1) call is removed.

Test Suite Extension
examples/flash_attention/test_example_flash_attention.py
  • Added test_example_gqa_bwd_tma_reduce_varlen(), which imports and invokes example_gqa_bwd_tma_reduce_varlen.main(), adding a CUDA-requiring test case.
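The flag behavior described above can be pictured with a small, self-contained sketch. This is only an assumption about how such parsing could look, not the example's actual argument handling (note the review comment further down: main() itself still defaults to the atomic path when called programmatically).

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_atomic", action="store_true")
parser.add_argument("--use_split", action="store_true")
args = parser.parse_args([])  # simulate a run with no flags given

# with neither flag set, fall back to the split-based dQ update
use_atomic = args.use_atomic and not args.use_split
print("split-based update" if not use_atomic else "atomic update")  # -> split-based update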

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant CLI as CLI Flags
    participant Example as example_gqa_bwd_tma_reduce_varlen.main
    participant Backward as GQA Backward Path
    participant Atomic as atomic_add (dQ)
    participant Split as split-based update
    participant Test as test runner

    Note over CLI,Example: start example with flags (or none)
    CLI->>Example: invoke main()
    Example->>Backward: run backward accumulation
    alt --use_atomic
        Backward->>Atomic: choose atomic path
        Atomic->>Atomic: atomic_add(..., memory_order="release")
    else --use_split
        Backward->>Split: choose split-based update
    else (no flag)
        Backward->>Split: default to split-based update
    end
    Test->>Example: test invokes main() (CUDA)
    Note over Atomic,Split: dK/dV checks + dQ assertion re-enabled

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • chengyupku
  • LeiWang1999

Poem

🐰
I hopped through kernels, tidy and spry,
Switched splits for atomics when flags pass by.
dQ wakes up and checks its rhyme,
Memory release keeps timing in time.
A tiny hop for code, a carrot for CI. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check — ✅ Passed: the title "[BugFix] Add memory order and testing script for split version GQA bwd kernel" is clearly related to the main changes in the changeset. It accurately describes the two primary additions: memory_order="release" for dQ accumulation in the atomic operation, and the new test function test_example_gqa_bwd_tma_reduce_varlen(). The title is specific to the GQA backward kernel and appropriately framed as a bug fix. While the changeset also includes secondary adjustments such as the seed-handling removal and the default-behavior change, it reasonably focuses on the key functional improvements without covering every detail.
✨ Finishing touches
  • 📝 Generate docstrings
  • 🧪 Generate unit tests (beta)
    • Create PR with unit tests
    • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9ab3bbb and 4d1ccc5.

📒 Files selected for processing (1)
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (2)

526-526: Elementwise atomic for dQ with release — OK; watch contention.

Fine for correctness. For throughput, consider block/warp-aggregated accumulation followed by a single atomic per row/col tile in a follow-up (see the sketch below).
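A plain-PyTorch sketch of the aggregation idea (an illustrative assumption, not the kernel's tilelang code): reduce a tile's contributions locally first, then apply them in one update instead of one per contribution.

import torch

dQ = torch.zeros(8, 64)            # hypothetical global dQ accumulator tile
contribs = torch.randn(4, 8, 64)   # hypothetical partial dQ from 4 KV blocks

# per-contribution updates (what elementwise atomics effectively do)
dq_per_contrib = dQ.clone()
for c in contribs:
    dq_per_contrib += c

# aggregated: reduce locally, then flush once per tile
dq_aggregated = dQ.clone() + contribs.sum(dim=0)

# both strategies produce the same result; the aggregated form issues far fewer updates
torch.testing.assert_close(dq_per_contrib, dq_aggregated)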


785-787: Default behavior mismatch between CLI and programmatic calls.

CLI defaults to split (use_atomic=False), but main() still defaults to True. Either flip main’s default to False or always pass use_atomic explicitly at call sites to avoid confusion.

- def main(..., use_atomic: bool = True):
+ def main(..., use_atomic: bool = False):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5683e6a and 9ab3bbb.

📒 Files selected for processing (2)
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (3 hunks)
  • examples/flash_attention/test_example_flash_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/test_example_flash_attention.py (1)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)
  • main (691-758)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (2)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (2)

367-372: Atomic adds now use release ordering — good. Please confirm necessity.

Change looks correct and harmless. If no in-kernel acquire reads of dQ/dK/dV occur, relaxed ordering would also be sufficient and slightly cheaper. If you rely on release ordering for happens-before visibility to downstream kernels, consider noting it in a comment.

Also applies to: 373-383


742-742: Re‑enable dQ check — good.

Looks fine; if the atomic path is used, watch for rare non-deterministic tolerance trips. The current rtol/atol should be fine with fp32 accumulation.
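For reference, a hedged sketch of what such a gradient check looks like; the tensor shapes and the rtol/atol values are illustrative assumptions, not the ones used in the example.

import torch

dq_ref = torch.randn(32, 8, 64)                  # hypothetical reference dQ
dq = dq_ref + 1e-4 * torch.randn_like(dq_ref)    # stand-in for the kernel output

# fp32 accumulation keeps the kernel result within a modest tolerance of the reference
torch.testing.assert_close(dq, dq_ref, rtol=1e-2, atol=1e-2)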

Comment on lines +18 to 21
@tilelang.testing.requires_cuda
def test_example_gqa_bwd_tma_reduce_varlen():
    example_gqa_bwd_tma_reduce_varlen.main()


⚠️ Potential issue | 🟠 Major

Test does not exercise “split” path and lacks compute-capability guard.

As written, main() runs the atomic path (use_atomic defaults to True) and may fail on pre-SM90 GPUs. Call the split path explicitly and add the same compute-capability guard used elsewhere.

-@tilelang.testing.requires_cuda
-def test_example_gqa_bwd_tma_reduce_varlen():
-    example_gqa_bwd_tma_reduce_varlen.main()
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_gqa_bwd_tma_reduce_varlen():
+    example_gqa_bwd_tma_reduce_varlen.main(use_atomic=False)
🤖 Prompt for AI Agents
In examples/flash_attention/test_example_flash_attention.py around lines 18-21, the test calls example_gqa_bwd_tma_reduce_varlen.main(), which by default runs the atomic path and may fail on GPUs older than SM90. Modify the test to invoke the split path explicitly by calling main(use_atomic=False), and add the same compute-capability guard used elsewhere (e.g., @tilelang.testing.requires_cuda_compute_version_ge(9, 0), as in the suggested diff) above the test so it only runs on supported GPUs.

@LeiWang1999 merged commit 853f9c3 into tile-ai:main Oct 27, 2025
6 checks passed