
Conversation


@Rachmanino Rachmanino commented Sep 25, 2025

The second part of #831.

This pull request introduces a backward pass implementation for multi-head attention in the BHSD layout, along with corresponding tests. The main addition is the example_mha_bwd_bhsd.py module, which provides a custom backward kernel for flash attention; the new backward examples are also wired into the test suites for both the attention sink and flash attention examples.

New backward kernel implementation for flash attention (BHSD layout):

  • Added examples/flash_attention/example_mha_bwd_bhsd.py, which implements a custom backward kernel for multi-head attention with the BHSD layout using TileLang and PyTorch. This includes kernel definitions, autograd integration, reference checks, and benchmarking utilities.
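For orientation, below is a minimal eager-mode sketch of the kind of BHSD reference check such an example performs; this is plain PyTorch rather than the TileLang kernel itself, and the tensor shapes and causal flag are illustrative assumptions.

import torch

def ref_attention_bhsd(q, k, v, causal=True):
    # q, k, v: [batch, heads, seq_len, head_dim] (the BHSD layout)
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhsd,bhtd->bhst", q, k) * scale
    if causal:
        s = q.shape[2]
        mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bhtd->bhsd", probs, v)

# A gradient check would compare the custom kernel's outputs and grads
# against this reference, e.g. with torch.testing.assert_close.
q, k, v = (torch.randn(1, 2, 16, 32, requires_grad=True) for _ in range(3))
ref_attention_bhsd(q, k, v).sum().backward()  # populates q.grad, k.grad, v.grad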

Test integration for the new kernel:

  • Added import and test case for example_mha_bwd_bhsd in examples/flash_attention/test_example_flash_attention.py to validate the new backward kernel. [1] [2]
  • Added import and test case for example_mha_sink_bwd_bhsd in examples/attention_sink/test_example_attention_sink.py to ensure coverage in the attention sink test suite. [1] [2]
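For context, the new test wiring looks roughly like the sketch below. This is a reconstruction, not the diff itself: the tilelang.testing.requires_cuda decorator and tilelang.testing.main() entry point are assumed from the surrounding suite.

# Sketch of the addition to examples/flash_attention/test_example_flash_attention.py
import tilelang.testing

import example_mha_bwd_bhsd


@tilelang.testing.requires_cuda
def test_example_mha_bwd_bhsd():
    # Runs the example end to end: build the kernels, verify against the
    # PyTorch reference, and report benchmark numbers.
    example_mha_bwd_bhsd.main()


if __name__ == "__main__":
    tilelang.testing.main()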

Summary by CodeRabbit

  • New Features

    • Added self-contained attention examples (MHA & GQA) with fused forward/backward, optional sliding-window, autograd-friendly API, reference implementations, CLI, and benchmarking.
    • Added a flash-attention backward example with verification and performance benchmarking.
  • Tests

    • Added tests exercising MHA/GQA attention sinks (including sliding-window variants) and a flash-attention backward test.
  • Refactor

    • Made causal-flag and shared-buffer usage consistent in a pipelined backward kernel.

Rachmanino and others added 17 commits September 20, 2025 09:57
- Introduced a new example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Added a reference attention function to validate the implementation against PyTorch.
- Included argument parsing for command-line execution of the example.
- Removed the old example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Introduced a new example script that modifies the attention mechanism to enhance performance and maintainability.
- Updated argument parsing and reference functions to align with the new implementation.
- Added a `window_size` parameter to the `flashattn` function to enable sliding window attention.
- Implemented assertions to ensure `window_size` is compatible with `block_N` (a sketch of such a check follows this commit list).
- Updated the main function to include a `tune` option for performance tuning.
- Introduced a new test file to validate both full attention and sliding window scenarios.
- Adjusted FLOPS calculation to account for the sliding window configuration.
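The window_size/block_N compatibility check mentioned in the commit notes could look like the hypothetical helper below; the exact constraint used in the example may differ, so treat this purely as a sketch.

def validate_window(window_size, block_N):
    # Hypothetical check: a sliding window tiles cleanly when it spans a
    # whole number of key/value blocks.
    if window_size is None:
        return  # full or purely causal attention, nothing to check
    assert window_size % block_N == 0, (
        f"window_size={window_size} should be a multiple of block_N={block_N}")

validate_window(256, 64)    # passes
# validate_window(200, 64)  # would raise AssertionError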

coderabbitai bot commented Sep 25, 2025

Warning

Rate limit exceeded

@Rachmanino has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 8 minutes and 49 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between 4b53138a5e8af6e5a7f52cb6ec51cfb27d4697ac and 96f807e.

📒 Files selected for processing (3)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (1 hunks)
  • examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (4 hunks)

Walkthrough

Adds multiple TileLang attention examples (MHA and GQA) implementing tiled forward/backward passes, including sink-aware variants and optional sliding windows, together with autograd wrappers, PyTorch reference implementations, and CLI/benchmark harnesses; adds tests for them; and applies a small API fix to a pipelined backward example (causal-flag rename and shared-buffer shape change).

Changes

  • Attention Sink — GQA example (examples/attention_sink/example_gqa_sink_bwd_bhsd.py): New GQA attention-sink example: forward kernel, backward preprocess/postprocess, full backward, dsink kernel, autograd Function _attention, public attention alias, reference PyTorch implementation, CLI/benchmark, SM-version gating, optional sliding window.
  • Attention Sink — MHA example (examples/attention_sink/example_mha_sink_bwd_bhsd.py): New MHA attention-sink example mirroring the GQA one: forward/backward kernels, preprocess/postprocess, dsink kernel, autograd wrapper, attention alias, reference PyTorch program, CLI/benchmark, SM checks, optional sliding window.
  • Attention Sink — Tests (examples/attention_sink/test_example_attention_sink.py): Adds four tests invoking the MHA/GQA backward examples (with and without sliding window), each calling the module's main().
  • Flash Attention — MHA backward (BHSD) (examples/flash_attention/example_mha_bwd_bhsd.py): New flash-attention BHSD example: fused forward/backward TileLang kernels, backward preprocess/postprocess, dQ layout utility, autograd wrapper _attention, attention alias, reference PyTorch implementation, CLI/benchmark and verification.
  • Flash Attention — WGMMA pipelined tweak (examples/flash_attention/example_mha_bwd_wgmma_pipelined.py): API/implementation tweak: renamed parameter is_casual → is_causal and changed the shared-buffer allocations dv_shared/dk_shared from [block_N, dim] to [block_M, dim]; updated conditions referencing the causal flag.
  • Flash Attention — Tests (examples/flash_attention/test_example_flash_attention.py): Adds the CUDA-guarded test test_example_mha_bwd_bhsd() importing the new BHSD example.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant App as main()/tests
  participant Torch as PyTorch
  participant Auto as _attention (autograd)
  participant Kern as TileLang Kernels

  App->>Torch: prepare Q,K,V,(Sinks), window_size?
  App->>Auto: attention(Q,K,V, Sinks?, window_size?)
  activate Auto
  Auto->>Kern: flashattn_fwd(...)
  Kern-->>Auto: O, lse
  Auto-->>App: O (ctx saved: Q,K,V,Sinks,lse,window_size)
  deactivate Auto

  App->>Torch: loss.backward()
  Torch->>Auto: backward(dO)
  activate Auto
  Auto->>Kern: flashattn_bwd_preprocess(O,dO) -> Delta
  Auto->>Kern: flashattn_bwd_dsink(Sinks,Delta,lse) -> dSinks
  Auto->>Kern: flashattn_bwd(Q,K,V,dO,lse,Delta, window_size?) -> dQ,dK,dV
  Auto->>Kern: flashattn_bwd_postprocess(dQ_tile) -> dQ
  Kern-->>Auto: dQ,dK,dV,dSinks
  Auto-->>Torch: dq, dk, dv, dsinks
  deactivate Auto
sequenceDiagram
  autonumber
  participant AutoG as _attention (GQA)
  participant KernG as TileLang Kernels (GQA)
  note over AutoG,KernG: GQA uses grouped heads (param `groups`)

  AutoG->>KernG: flashattn_fwd(..., groups, window_size?)
  KernG-->>AutoG: O, lse
  AutoG->>KernG: flashattn_bwd_preprocess -> Delta
  AutoG->>KernG: flashattn_bwd_dsink -> dSinks
  AutoG->>KernG: flashattn_bwd(..., groups, window_size?) -> dQ,dK,dV
  AutoG->>KernG: flashattn_bwd_postprocess -> dQ
  KernG-->>AutoG: dQ,dK,dV,dSinks
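Read together, the two diagrams describe the same backward recipe. A schematic sketch of that orchestration is shown below; the kernel handles and their names are placeholders supplied by the caller, not the examples' real API.

def run_sink_backward(kernels, q, k, v, sinks, o, do, lse):
    # Schematic of the backward order in the diagrams above. `kernels` is
    # assumed to expose four compiled kernels under placeholder names.
    delta = kernels.preprocess(o, do)                      # Delta = rowsum(O * dO)
    dsinks = kernels.dsink(sinks, delta, lse)              # gradient w.r.t. the sink logits
    dq_acc, dk, dv = kernels.bwd(q, k, v, do, lse, delta)  # main tiled backward pass
    dq = kernels.postprocess(dq_acc)                       # undo the atomic-add-friendly dQ layout
    return dq, dk, dv, dsinks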

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

I twitch my ears at kernels new,
Tiles align, the gradients flew—
Sinks that anchor, windows glide,
Backward rivers coincide.
I hop, I code, I benchmark too—hooray! 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title clearly summarizes the main change, indicating the addition of efficient attention-sink backward implementations along with corresponding tests, matching the core content of the pull request without extra noise or ambiguity.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@Rachmanino Rachmanino marked this pull request as ready for review September 26, 2025 10:45

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (6)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)

10-18: Consider extracting the error message to a constant or class.

The error message is quite verbose. Consider extracting it to improve readability and maintainability.

Apply this diff to improve the structure:

+UNSUPPORTED_SM_ERROR = "Unsupported SM version: {}"
+
 def get_bwd_configs():
     sm_major, sm_minor = torch.cuda.get_device_capability()
     sm_version = sm_major * 10 + sm_minor
     if sm_version == 80:
         return 64, 64, 1, 128
     elif sm_version == 90:
         return 128, 128, 2, 256
     else:
-        raise ValueError(f"Unsupported SM version: {sm_version}")
+        raise ValueError(UNSUPPORTED_SM_ERROR.format(sm_version))

168-171: Consider renaming ambiguous variable l for improved readability.

The variable name l can be easily confused with 1 or I in many fonts.

Apply this diff to improve readability:

 def make_dq_layout(dQ):
     # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
     return T.Layout(dQ.shape,
-                    lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+                    lambda b, h, seq, d: [b, h, seq // 8, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])
examples/flash_attention/example_mha_bwd_bhsd.py (2)

4-4: Remove wildcard import for better code clarity.

Using wildcard imports can lead to namespace pollution and makes it harder to track dependencies.

Since the wildcard import from tilelang.autotuner doesn't appear to be used in this file, you can simply remove it:

-from tilelang.autotuner import *

118-121: Consider renaming ambiguous variable l for improved readability.

The variable name l can be easily confused with 1 or I in many fonts.

Apply this diff to improve readability:

 def make_dq_layout(dQ):
     # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
     return T.Layout(dQ.shape,
-                    lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+                    lambda b, h, seq, d: [b, h, seq // 8, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])
examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)

10-18: Consider extracting the error message to a constant.

Similar to the GQA implementation, consider extracting the error message for consistency.

Apply this diff to improve the structure:

+UNSUPPORTED_SM_ERROR = "Unsupported SM version: {}"
+
 def get_bwd_configs():
     sm_major, sm_minor = torch.cuda.get_device_capability()
     sm_version = sm_major * 10 + sm_minor
     if sm_version == 80:
         return 64, 64, 1, 128
     elif sm_version == 90:
         return 128, 128, 2, 256
     else:
-        raise ValueError(f"Unsupported SM version: {sm_version}")
+        raise ValueError(UNSUPPORTED_SM_ERROR.format(sm_version))

168-171: Consider renaming ambiguous variable l for improved readability.

The variable name l can be easily confused with 1 or I in many fonts.

Apply this diff to improve readability:

 def make_dq_layout(dQ):
     # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
     return T.Layout(dQ.shape,
-                    lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+                    lambda b, h, seq, d: [b, h, seq // 8, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aa0b109 and 985931d.

📒 Files selected for processing (6)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (1 hunks)
  • examples/attention_sink/test_example_attention_sink.py (2 hunks)
  • examples/flash_attention/example_mha_bwd_bhsd.py (1 hunks)
  • examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/test_example_flash_attention.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
examples/flash_attention/test_example_flash_attention.py (1)
examples/flash_attention/example_mha_bwd_bhsd.py (1)
  • main (297-346)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-96)
tilelang/language/__init__.py (1)
  • annotate_layout (104-142)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (15)
  • get_bwd_configs (10-18)
  • flashattn_fwd (25-132)
  • flash_fwd (48-130)
  • flashattn_bwd_preprocess (139-165)
  • make_dq_layout (168-171)
  • flashattn_bwd_postprocess (178-196)
  • flashattn_bwd (202-304)
  • flash_bwd (217-302)
  • flashattn_bwd_dsink (308-334)
  • flash_bwd_dsink (314-332)
  • _attention (337-377)
  • forward (340-347)
  • backward (350-377)
  • maybe_contiguous (355-358)
  • ref_program (385-428)
examples/flash_attention/example_mha_bwd_bhsd.py (7)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-96)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/__init__.py (1)
  • annotate_layout (104-142)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/fill.py (2)
  • fill (9-21)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
examples/attention_sink/test_example_attention_sink.py (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
  • main (431-492)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
  • main (438-496)
🪛 Ruff (0.13.1)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py

18-18: Avoid specifying long messages outside the exception class

(TRY003)


147-147: Ambiguous variable name: O

(E741)


171-171: Ambiguous variable name: l

(E741)


457-457: Ambiguous variable name: O

(E741)

examples/attention_sink/example_mha_sink_bwd_bhsd.py

18-18: Avoid specifying long messages outside the exception class

(TRY003)


147-147: Ambiguous variable name: O

(E741)


171-171: Ambiguous variable name: l

(E741)


461-461: Ambiguous variable name: O

(E741)

examples/flash_attention/example_mha_bwd_bhsd.py

4-4: from tilelang.autotuner import * used; unable to detect undefined names

(F403)


97-97: Ambiguous variable name: O

(E741)


121-121: Ambiguous variable name: l

(E741)


314-314: Ambiguous variable name: O

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-amd
🔇 Additional comments (10)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)

187-195: Shared-buffer shape now matches fragment writes

The line 187 change aligns the shared buffers with the [block_M, dim] fragments being accumulated, eliminating the previous mismatch when block_M != block_N; the swizzled layout annotations keep the wgmma-friendly ordering consistent. Looks good.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)

21-133: Well-implemented flash attention forward kernel with attention sink support.

The implementation correctly incorporates attention sinks with sliding window support. The use of device capability detection and adaptive block configurations is appropriate, and the kernel includes proper infinity checks for sliding window scenarios.


199-305: Robust backward kernel implementation with proper gradient handling.

The backward pass correctly computes gradients for Q, K, V using atomic operations for accumulation, properly handles sliding window masks, and efficiently manages shared memory with layout annotations.

examples/flash_attention/example_mha_bwd_bhsd.py (3)

9-83: Well-structured flash attention forward implementation.

The forward kernel correctly implements the flash attention algorithm with proper memory layout annotations, efficient use of shared memory, and appropriate scaling factors.


149-240: Comprehensive backward pass with proper gradient accumulation.

The backward kernel correctly handles gradient computation with atomic operations for thread-safe accumulation, efficient pipeline stages, and proper memory layout management through swizzled layouts.


265-266: Clarify block_N thresholds in backward bhsd example
The forward pass uses block_N = 64 if D_HEAD ≤ 128 else 32 (line 248), but the backward pass sets it to 64 if D_HEAD ≤ 64 else 32 (line 266). All other bhsd examples follow the same 128 threshold in forward but not in backward. If the lower backward threshold isn’t a deliberate performance tweak for this kernel, align both thresholds (use ≤ 128) for consistency.
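To make the divergence concrete, the two code paths quoted above differ only in the head-dimension threshold; a hypothetical helper expressing both (illustrative only):

def pick_block_n(d_head, threshold):
    # forward currently uses threshold=128, backward uses threshold=64
    return 64 if d_head <= threshold else 32

assert pick_block_n(96, threshold=128) == 64  # forward path
assert pick_block_n(96, threshold=64) == 32   # backward path disagrees for 64 < D_HEAD <= 128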

examples/flash_attention/test_example_flash_attention.py (1)

26-29: New test successfully integrates MHA backward BHSD example.

The test follows the established pattern and correctly guards with CUDA requirement.

examples/attention_sink/test_example_attention_sink.py (1)

44-62: Comprehensive test coverage for attention sink backward implementations.

The new tests appropriately cover both MHA and GQA variants with and without sliding windows, maintaining consistency with the existing test patterns.

examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)

21-133: Well-implemented flash attention forward with attention sink support.

The forward kernel correctly incorporates attention sinks with sliding window logic and proper memory management. The commented-out Q_local code suggests potential future optimizations that have been thoughtfully deferred.


375-377: Ignore zero-initialization suggestion for dk and dv. In example_mha_sink_bwd_bhsd.py, unlike dq, both dk and dv are fully written via T.gemm and T.copy (no T.atomic_add), so using torch.empty is intentional.

Likely an incorrect or invalid review comment.
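The distinction drawn here maps to a familiar allocation pattern; the sketch below is illustrative rather than the example's actual code.

import torch

def alloc_grads(q, k, v):
    dq = torch.zeros_like(q)  # accumulated with atomic adds, so it must start zeroed
    dk = torch.empty_like(k)  # written in full by the kernel (gemm + copy), so empty is fine
    dv = torch.empty_like(v)
    return dq, dk, dv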
