[Example] Add efficient attention sink backward implementations and tests #877
Conversation
- Introduced a new example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Added a reference attention function to validate the implementation against PyTorch.
- Included argument parsing for command-line execution of the example.

- Removed the old example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Introduced a new example script that modifies the attention mechanism to enhance performance and maintainability.
- Updated argument parsing and reference functions to align with the new implementation.

- Added a `window_size` parameter to the `flashattn` function to enable sliding window attention.
- Implemented assertions to ensure `window_size` is compatible with `block_N`.
- Updated the main function to include a `tune` option for performance tuning.
- Introduced a new test file to validate both full attention and sliding window scenarios.
- Adjusted FLOPS calculation to account for the sliding window configuration.
…modate the new sequence length parameters.
…rresponding test cases
…orresponding test case
Walkthrough

Adds multiple TileLang attention examples (MHA and GQA) implementing tiled forward/backward (including sink-aware variants and optional sliding-window), autograd wrappers, PyTorch reference implementations, CLI/benchmark harnesses, tests, and a small API fix in a pipelined backward example (causal flag and shared-buffer shape changes).
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant App as main()/tests
participant Torch as PyTorch
participant Auto as _attention (autograd)
participant Kern as TileLang Kernels
App->>Torch: prepare Q,K,V,(Sinks), window_size?
App->>Auto: attention(Q,K,V, Sinks?, window_size?)
activate Auto
Auto->>Kern: flashattn_fwd(...)
Kern-->>Auto: O, lse
Auto-->>App: O (ctx saved: Q,K,V,Sinks,lse,window_size)
deactivate Auto
App->>Torch: loss.backward()
Torch->>Auto: backward(dO)
activate Auto
Auto->>Kern: flashattn_bwd_preprocess(O,dO) -> Delta
Auto->>Kern: flashattn_bwd_dsink(Sinks,Delta,lse) -> dSinks
Auto->>Kern: flashattn_bwd(Q,K,V,dO,lse,Delta, window_size?) -> dQ,dK,dV
Auto->>Kern: flashattn_bwd_postprocess(dQ_tile) -> dQ
Kern-->>Auto: dQ,dK,dV,dSinks
Auto-->>Torch: dq, dk, dv, dsinks
deactivate Auto
```

```mermaid
sequenceDiagram
autonumber
participant AutoG as _attention (GQA)
participant KernG as TileLang Kernels (GQA)
note over AutoG,KernG: GQA uses grouped heads (param `groups`)
AutoG->>KernG: flashattn_fwd(..., groups, window_size?)
KernG-->>AutoG: O, lse
AutoG->>KernG: flashattn_bwd_preprocess -> Delta
AutoG->>KernG: flashattn_bwd_dsink -> dSinks
AutoG->>KernG: flashattn_bwd(..., groups, window_size?) -> dQ,dK,dV
AutoG->>KernG: flashattn_bwd_postprocess -> dQ
KernG-->>AutoG: dQ,dK,dV,dSinks
```
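To make the diagrammed forward contract concrete, here is a minimal eager PyTorch sketch of the `flashattn_fwd` step (O and lse with sinks and an optional sliding window). This is a sketch under assumed conventions (BHSD layout, per-head sink logits, causal mask narrowed to the last `window_size` keys), not the PR's tiled kernel or necessarily its reference code.

```python
import torch


def ref_fwd(q, k, v, sinks, window_size=None, sm_scale=None):
    """Eager sketch of the forward step in the diagrams above.

    Assumed shapes: q, k, v are [B, H, S, D] (BHSD) and sinks is [H], one extra
    logit per head that absorbs probability mass but carries no value vector.
    Returns (o, lse), mirroring the kernel's outputs.
    """
    B, H, S, D = q.shape
    scale = sm_scale if sm_scale is not None else D ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * scale

    # Causal mask, optionally narrowed to a sliding window of `window_size` keys.
    row = torch.arange(S, device=q.device)[:, None]
    col = torch.arange(S, device=q.device)[None, :]
    visible = col <= row
    if window_size is not None:
        visible &= col > row - window_size
    scores = scores.masked_fill(~visible, float("-inf"))

    # The sink adds one logit per (head, query row) to the softmax denominator.
    sink = sinks.float().view(1, H, 1, 1).expand(B, H, S, 1)
    lse = torch.logsumexp(torch.cat([scores, sink], dim=-1), dim=-1)  # [B, H, S]
    probs = torch.exp(scores - lse[..., None])  # sink column contributes no value
    o = torch.einsum("bhqk,bhkd->bhqd", probs, v.float())
    return o.to(q.dtype), lse
```

The tiled kernels compute the same quantities block by block; an eager form like this is mainly useful for validating outputs and gradients on small shapes.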
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning) · ✅ Passed checks (2 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run […]. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 4
🧹 Nitpick comments (6)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)
10-18: Consider extracting the error message to a constant or class.

The error message is quite verbose. Consider extracting it to improve readability and maintainability.
Apply this diff to improve the structure:
```diff
+UNSUPPORTED_SM_ERROR = "Unsupported SM version: {}"
+
 def get_bwd_configs():
     sm_major, sm_minor = torch.cuda.get_device_capability()
     sm_version = sm_major * 10 + sm_minor
     if sm_version == 80:
         return 64, 64, 1, 128
     elif sm_version == 90:
         return 128, 128, 2, 256
     else:
-        raise ValueError(f"Unsupported SM version: {sm_version}")
+        raise ValueError(UNSUPPORTED_SM_ERROR.format(sm_version))
```
168-171: Consider renaming ambiguous variable `l` for improved readability.

The variable name `l` can be easily confused with `1` or `I` in many fonts. Apply this diff to improve readability:
```diff
 def make_dq_layout(dQ):
     # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
     return T.Layout(dQ.shape,
-                    lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+                    lambda b, h, seq, d: [b, h, seq // 8, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])
```

examples/flash_attention/example_mha_bwd_bhsd.py (2)
4-4: Remove wildcard import for better code clarity.

Using wildcard imports can lead to namespace pollution and makes it harder to track dependencies.
Since the wildcard import from `tilelang.autotuner` doesn't appear to be used in this file, you can simply remove it:

```diff
-from tilelang.autotuner import *
```
118-121: Consider renaming ambiguous variable `l` for improved readability.

The variable name `l` can be easily confused with `1` or `I` in many fonts. Apply this diff to improve readability:
```diff
 def make_dq_layout(dQ):
     # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
     return T.Layout(dQ.shape,
-                    lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+                    lambda b, h, seq, d: [b, h, seq // 8, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])
```

examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)
10-18: Consider extracting the error message to a constant.

Similar to the GQA implementation, consider extracting the error message for consistency.
Apply this diff to improve the structure:
```diff
+UNSUPPORTED_SM_ERROR = "Unsupported SM version: {}"
+
 def get_bwd_configs():
     sm_major, sm_minor = torch.cuda.get_device_capability()
     sm_version = sm_major * 10 + sm_minor
     if sm_version == 80:
         return 64, 64, 1, 128
     elif sm_version == 90:
         return 128, 128, 2, 256
     else:
-        raise ValueError(f"Unsupported SM version: {sm_version}")
+        raise ValueError(UNSUPPORTED_SM_ERROR.format(sm_version))
```
168-171: Consider renaming ambiguous variable `l` for improved readability.

The variable name `l` can be easily confused with `1` or `I` in many fonts. Apply this diff to improve readability:
```diff
 def make_dq_layout(dQ):
     # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
     return T.Layout(dQ.shape,
-                    lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+                    lambda b, h, seq, d: [b, h, seq // 8, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1 hunks)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (1 hunks)
- examples/attention_sink/test_example_attention_sink.py (2 hunks)
- examples/flash_attention/example_mha_bwd_bhsd.py (1 hunks)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1 hunks)
- examples/flash_attention/test_example_flash_attention.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
examples/flash_attention/test_example_flash_attention.py (1)
examples/flash_attention/example_mha_bwd_bhsd.py (1)
main(297-346)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
tilelang/language/allocate.py (1)
alloc_shared(21-36)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)
tilelang/jit/__init__.py (1)
jit(237-310)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-96)
tilelang/language/__init__.py (1)
annotate_layout(104-142)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (15)
get_bwd_configs(10-18)
flashattn_fwd(25-132)
flash_fwd(48-130)
flashattn_bwd_preprocess(139-165)
make_dq_layout(168-171)
flashattn_bwd_postprocess(178-196)
flashattn_bwd(202-304)
flash_bwd(217-302)
flashattn_bwd_dsink(308-334)
flash_bwd_dsink(314-332)
_attention(337-377)
forward(340-347)
backward(350-377)
maybe_contiguous(355-358)
ref_program(385-428)
examples/flash_attention/example_mha_bwd_bhsd.py (7)
tilelang/jit/__init__.py (1)
jit(237-310)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-96)
tilelang/language/allocate.py (2)
alloc_shared(21-36)
alloc_fragment(53-64)
tilelang/language/__init__.py (1)
annotate_layout(104-142)
tilelang/language/copy.py (1)
copy(84-152)
tilelang/language/fill.py (2)
fill(9-21)
clear(24-48)
tilelang/language/pipeline.py (1)
Pipelined(9-46)
examples/attention_sink/test_example_attention_sink.py (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
main(431-492)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
main(438-496)
🪛 Ruff (0.13.1)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
18-18: Avoid specifying long messages outside the exception class
(TRY003)
147-147: Ambiguous variable name: O
(E741)
171-171: Ambiguous variable name: l
(E741)
457-457: Ambiguous variable name: O
(E741)
examples/attention_sink/example_mha_sink_bwd_bhsd.py
18-18: Avoid specifying long messages outside the exception class
(TRY003)
147-147: Ambiguous variable name: O
(E741)
171-171: Ambiguous variable name: l
(E741)
461-461: Ambiguous variable name: O
(E741)
examples/flash_attention/example_mha_bwd_bhsd.py
4-4: from tilelang.autotuner import * used; unable to detect undefined names
(F403)
97-97: Ambiguous variable name: O
(E741)
121-121: Ambiguous variable name: l
(E741)
314-314: Ambiguous variable name: O
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-amd
🔇 Additional comments (10)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
187-195: Shared-buffer shape now matches fragment writes

Line 187 change aligns the shared buffers with the `[block_M, dim]` fragments we accumulate, eliminating the previous mismatch when `block_M != block_N`, and the swizzled layout annotations keep the wgmma-friendly ordering consistent. Looks good.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)
21-133: Well-implemented flash attention forward kernel with attention sink support.

The implementation correctly incorporates attention sinks with sliding window support. The use of device capability detection and adaptive block configurations is appropriate, and the kernel includes proper infinity checks for sliding window scenarios.
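As a side note on how the sliding window interacts with the tiling: with a causal mask plus a window of `window_size` keys, each query block only needs to visit a bounded range of key blocks. The helper below is purely illustrative (the names and the inclusive-window convention are assumptions, and the kernel derives its loop bounds in TileLang rather than Python), but it shows the arithmetic:

```python
def kv_block_range(q_block_start, block_M, block_N, seq_len, window_size=None):
    """[lo, hi) range of K/V block indices a query block must touch, assuming a
    causal mask where row i sees keys j with i - window_size < j <= i.
    Illustrative only; the TileLang kernels express this bound differently."""
    last_q = min(q_block_start + block_M, seq_len) - 1  # newest query row in the block
    first_k = 0 if window_size is None else max(0, q_block_start - window_size + 1)
    lo = first_k // block_N
    hi = last_q // block_N + 1  # causal bound for the newest query row
    return lo, hi


# Example: with block_M = block_N = 64 and a 128-key window, the query block
# starting at row 256 only needs key blocks [2, 5) instead of all of [0, 5).
assert kv_block_range(256, 64, 64, 4096, window_size=128) == (2, 5)
```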
199-305: Robust backward kernel implementation with proper gradient handling.

The backward pass correctly computes gradients for Q, K, V using atomic operations for accumulation, properly handles sliding window masks, and efficiently manages shared memory with layout annotations.
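For reference, these are the standard FlashAttention backward identities such a kernel realizes per batch and head, plus the sink gradient implied by the separate `flashattn_bwd_dsink` step. Here $S$ is the masked, scaled score matrix, $L_i$ the log-sum-exp including the per-head sink logit $s_h$, and $\Delta$ the row-wise quantity produced by the preprocess kernel; the PR's kernels may fold the scale or use base-2 exponentials, so constants can differ.

$$
\begin{aligned}
P_{ij} &= \exp(S_{ij} - L_i), \qquad \Delta_i = \textstyle\sum_d \mathrm{d}O_{id}\, O_{id},\\
\mathrm{d}V &= P^\top \mathrm{d}O, \qquad \mathrm{d}P = \mathrm{d}O\, V^\top, \qquad
\mathrm{d}S_{ij} = P_{ij}\left(\mathrm{d}P_{ij} - \Delta_i\right),\\
\mathrm{d}Q &= \mathrm{scale}\cdot \mathrm{d}S\, K, \qquad
\mathrm{d}K = \mathrm{scale}\cdot \mathrm{d}S^\top Q, \qquad
\mathrm{d}s_h = -\textstyle\sum_i \exp(s_h - L_i)\, \Delta_i .
\end{aligned}
$$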
examples/flash_attention/example_mha_bwd_bhsd.py (3)
9-83: Well-structured flash attention forward implementation.

The forward kernel correctly implements the flash attention algorithm with proper memory layout annotations, efficient use of shared memory, and appropriate scaling factors.
149-240: Comprehensive backward pass with proper gradient accumulation.

The backward kernel correctly handles gradient computation with atomic operations for thread-safe accumulation, efficient pipeline stages, and proper memory layout management through swizzled layouts.
265-266: Clarify block_N thresholds in backward bhsd example

The forward pass uses `block_N = 64 if D_HEAD ≤ 128 else 32` (line 248), but the backward pass sets it to `64 if D_HEAD ≤ 64 else 32` (line 266). All other bhsd examples follow the same 128 threshold in forward but not in backward. If the lower backward threshold isn't a deliberate performance tweak for this kernel, align both thresholds (use ≤ 128) for consistency.

examples/flash_attention/test_example_flash_attention.py (1)
26-29: New test successfully integrates MHA backward BHSD example.

The test follows the established pattern and correctly guards with CUDA requirement.
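For readers unfamiliar with that pattern, a minimal sketch of such a guarded test is shown below; the import path and the plain `pytest.mark.skipif` guard are placeholders for whatever helpers the repository's test suite actually uses.

```python
import pytest
import torch

from examples.flash_attention import example_mha_bwd_bhsd  # illustrative import path


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
def test_example_mha_bwd_bhsd():
    # The example's main() builds the kernels, runs them, and checks the results
    # against the PyTorch reference, so invoking it with defaults is the test.
    example_mha_bwd_bhsd.main()
```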
examples/attention_sink/test_example_attention_sink.py (1)
44-62: Comprehensive test coverage for attention sink backward implementations.

The new tests appropriately cover both MHA and GQA variants with and without sliding windows, maintaining consistency with the existing test patterns.
examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)
21-133: Well-implemented flash attention forward with attention sink support.

The forward kernel correctly incorporates attention sinks with sliding window logic and proper memory management. The commented-out Q_local code suggests potential future optimizations that have been thoughtfully deferred.
375-377: Ignore zero-initialization suggestion for `dk` and `dv`.

In `example_mha_sink_bwd_bhsd.py`, unlike `dq`, both `dk` and `dv` are fully written via `T.gemm` and `T.copy` (no `T.atomic_add`), so using `torch.empty` is intentional.

Likely an incorrect or invalid review comment.
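A tiny sketch of the allocation pattern that distinction implies (hypothetical helper, illustrative dtypes): only the buffer that is accumulated with atomic adds needs to start from zero.

```python
import torch


def alloc_grad_buffers(q, k, v):
    # dq is accumulated tile by tile with atomic adds, so it must be zeroed first
    # (often kept in float32 until a postprocess cast); dk and dv are written in
    # full by the backward kernel, so uninitialized storage via torch.empty is fine.
    dq_acc = torch.zeros(q.shape, dtype=torch.float32, device=q.device)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    return dq_acc, dk, dv
```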
Force-pushed from 4b53138 to 1d1f48e
The second part of #831.
This pull request introduces a new backward pass implementation for multi-head attention using the BHSD layout, along with corresponding tests. The main focus is on adding the `example_mha_bwd_bhsd.py` module, which provides a custom backward kernel for flash attention and integrates it into the test suite for both the attention sink and flash attention examples.

New backward kernel implementation for flash attention (BHSD layout):

- Added `examples/flash_attention/example_mha_bwd_bhsd.py`, which implements a custom backward kernel for multi-head attention with the BHSD layout using TileLang and PyTorch. This includes kernel definitions, autograd integration, reference checks, and benchmarking utilities.

Test integration for the new kernel:

- Added a test for `example_mha_bwd_bhsd` in `examples/flash_attention/test_example_flash_attention.py` to validate the new backward kernel. [1] [2]
- Added a test for `example_mha_sink_bwd_bhsd` in `examples/attention_sink/test_example_attention_sink.py` to ensure coverage in the attention sink test suite. [1] [2]

Summary by CodeRabbit
New Features
Tests
Refactor