[Example] Add examples to support efficient attention sink forward process #853
Conversation
- Introduced a new example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Added a reference attention function to validate the implementation against PyTorch.
- Included argument parsing for command-line execution of the example.
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
- Removed the old example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Introduced a new example script that modifies the attention mechanism to enhance performance and maintainability.
- Updated argument parsing and reference functions to align with the new implementation.
- Added a `window_size` parameter to the `flashattn` function to enable sliding window attention.
- Implemented assertions to ensure `window_size` is compatible with `block_N`.
- Updated the main function to include a `tune` option for performance tuning.
- Introduced a new test file to validate both full attention and sliding window scenarios.
- Adjusted FLOPS calculation to account for the sliding window configuration.
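For orientation, the sketch below shows one way a PyTorch reference for attention with sink tokens and an optional sliding window can be written. It is only a sketch under assumed conventions (a `[batch, heads, seq, dim]` layout, queries aligned to the end of the KV sequence, a window that includes the current position, and the function name itself), not the `ref_program` added in this PR.

```python
import torch
import torch.nn.functional as F


def sink_attention_reference(q, k, v, sinks, window_size=None):
    # Assumed shapes: q [B, H, Sq, D], k/v [B, H, Skv, D], sinks [H] (per-head sink logits).
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * scale
    seq_q, seq_kv = q.shape[2], k.shape[2]
    start_q = seq_kv - seq_q  # queries assumed aligned to the end of the KV sequence
    q_idx = torch.arange(seq_q, device=q.device)[:, None] + start_q
    k_idx = torch.arange(seq_kv, device=q.device)[None, :]
    mask = k_idx <= q_idx                       # causal mask
    if window_size is not None:
        mask &= k_idx > q_idx - window_size     # sliding window of `window_size` keys
    scores = scores.masked_fill(~mask, float("-inf"))
    # The sink acts as one extra "virtual" key per head: it participates in the
    # softmax normalization but contributes no value vector to the output.
    sink_col = sinks.float().view(1, -1, 1, 1).expand(q.shape[0], -1, seq_q, 1)
    probs = F.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]
    return torch.einsum("bhqk,bhkd->bhqd", probs, v.float()).to(v.dtype)
```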
Walkthrough

Adds three new attention-sink example programs (MHA and GQA variants, including WGMMA-pipelined kernels) with TileLang and Triton implementations, autotuning, reference correctness checks, input generators, benchmarking, and CLI entrypoints. Adds CUDA-gated tests exercising full and sliding-window modes for these examples.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant CLI as CLI/main()
participant Gen as gen_inputs()
participant TL as TileLang flashattn()
participant Ref as ref_program()
participant Tri as triton_program()
participant Bench as do_bench
User->>CLI: parse args (B,H,Sq,Skv,D,groups,window,tune)
alt tune mode
CLI->>TL: run autotuner & select config
TL-->>CLI: tuned kernel + latency
CLI->>Bench: benchmark tuned TL
Bench-->>CLI: results
else execute
CLI->>Gen: allocate Q,K,V,Sinks on CUDA
Gen-->>CLI: tensors
CLI->>TL: run TileLang kernel (with sinks/window)
TL-->>CLI: Out_TL
CLI->>Ref: compute ref_program
Ref-->>CLI: Out_Ref
CLI->>Tri: run Triton kernel
Tri-->>CLI: Out_Tri
CLI->>CLI: compare Out_TL vs Out_Ref and Out_Tri
CLI->>Bench: benchmark TL and Triton
Bench-->>CLI: latencies/throughput
end
Note over TL,Tri: Softmax/normalization incorporate per-head sinks and optional sliding-window masking
```
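The note above is the key numerical detail: the per-head sink competes for probability mass in the normalizer without contributing a value row. Below is a minimal, CPU-runnable sketch of that streaming-softmax bookkeeping; the tile size, float32 accumulation, and helper name are illustrative assumptions, not the TileLang or Triton kernel code.

```python
import torch


def streaming_softmax_with_sink(scores, values, sink_logit, block_n=128):
    """Flash-attention-style streaming softmax over one query block, with the
    per-head sink folded into the normalizer at the end (illustrative only)."""
    rows = scores.shape[0]
    row_max = torch.full((rows,), float("-inf"))      # running max per query row
    row_sum = torch.zeros(rows)                       # running softmax denominator
    acc = torch.zeros(rows, values.shape[-1])         # unnormalized output accumulator
    for start in range(0, scores.shape[1], block_n):
        s = scores[:, start:start + block_n]
        new_max = torch.maximum(row_max, s.max(dim=1).values)
        correction = torch.exp(row_max - new_max)     # rescale previous state
        p = torch.exp(s - new_max[:, None])
        row_sum = row_sum * correction + p.sum(dim=1)
        acc = acc * correction[:, None] + p @ values[start:start + block_n]
        row_max = new_max
    # The sink contributes to the denominator only: weights over real keys sum to < 1.
    row_sum = row_sum + torch.exp(torch.as_tensor(sink_logit) - row_max)
    return acc / row_sum[:, None]
```

A quick way to sanity-check this bookkeeping is to compare it against a plain softmax over the scores with one extra column holding the sink logit, dropping that column before the value product.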
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
…modate the new sequence length parameters.
Actionable comments posted: 2
🧹 Nitpick comments (10)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
13-14: Consider exposing more block sizes in autotuning configs

The current configurations only test `block_M=[128]`, `block_N=[128]`, limiting the autotuning exploration space. For better performance across different problem sizes, consider including more block size variations.

Consider expanding the configuration space:

```diff
-iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
+iter_params = dict(block_M=[64, 128], block_N=[64, 128], num_stages=[0, 1, 2], threads=[128, 256])
```
183-184: Remove unnecessary unsqueeze in ref_program

The unsqueeze operation on Lines 183-184 adds an unnecessary dimension that's immediately handled within the reference implementation. This can be simplified.

Consider removing the unsqueeze:

```diff
-query = query.transpose(1, 2).contiguous().unsqueeze(
-    3)  # align with the original function's interface
+query = query.transpose(1, 2).contiguous()
```

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (3)
195-196: Typo in comment: missing space

There's a missing space in the comment on Line 196.

Fix the typo:

```diff
-    3)  # align with the original function'sinterface
+    3)  # align with the original function's interface
```
236-253: Remove unused parameters from triton_kernel

The static analysis correctly identifies that parameters `Z`, `N_Q_CTX`, and `N_KV_CTX` are never used in the Triton kernel. These should be removed to clean up the interface.

Remove the unused parameters:

```diff
 @triton.jit
 def triton_kernel(
     Q,
     K,
     V,
     Sinks,
     sm_scale,
     Out,
-    Z,
     H,
-    N_Q_CTX,
-    N_KV_CTX,
     HEAD_DIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BANDWIDTH: tl.constexpr,
     start_q: tl.constexpr,
 ):
```

And update the call site:

```diff
 triton_kernel[grid](
     TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
     TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
     TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
     Sinks,
     1.0 / head_dim**0.5,
     TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
-    bs,
     n_heads,
-    N_Q_CTX=seq_q,
-    N_KV_CTX=seq_kv,
     HEAD_DIM=head_dim,
     BANDWIDTH=window_size,
     BLOCK_M=BLOCK_M,
     BLOCK_N=BLOCK_N,
     start_q=seq_kv - seq_q)
```
404-411: Consider more robust validation for Triton implementation

The current validation of the Triton implementation uses `torch.allclose` with a boolean check and prints pass/fail. Consider using `torch.testing.assert_close` for consistency with the TileLang validation.

Improve the validation approach:

```diff
-if torch.allclose(
-        triton_program(Q, K, V, sinks, window_size),
-        ref_program(Q, K, V, sinks, window_size),
-        rtol=1e-2,
-        atol=1e-2):
-    print("Checks for triton passed.✅")
-else:
-    print("Checks for triton failed.❌")
+try:
+    torch.testing.assert_close(
+        triton_program(Q, K, V, sinks, window_size),
+        ref_program(Q, K, V, sinks, window_size),
+        rtol=1e-2,
+        atol=1e-2)
+    print("Checks for triton passed.✅")
+except AssertionError as e:
+    print(f"Checks for triton failed.❌\n{e}")
```

examples/attention_sink/test_example_attention_sink.py (2)
3-5: Use absolute imports for better maintainability

The test file uses relative imports without the dot notation. Consider using absolute imports or explicit relative imports for clearer module dependencies.

Use explicit relative imports:

```diff
-import example_mha_sink_fwd_bhsd
-import example_mha_sink_fwd_bhsd_wgmma_pipelined
-import example_gqa_sink_fwd_bhsd_wgmma_pipelined
+from . import example_mha_sink_fwd_bhsd
+from . import example_mha_sink_fwd_bhsd_wgmma_pipelined
+from . import example_gqa_sink_fwd_bhsd_wgmma_pipelined
```
9-40: Consider parameterized tests for better coverage

The tests use hardcoded values (e.g., `window_size=128`). Consider using parameterized tests to cover more configurations and edge cases.

Would you like me to help create parameterized test cases that cover various batch sizes, sequence lengths, and window sizes to ensure robustness across different configurations? A minimal sketch is shown below.
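A minimal sketch of such a parameterization follows; the positional argument order of `main()` and the shape values are assumptions taken from the CLI description, and the CUDA-gating decorators used by the real test file are omitted for brevity.

```python
# Hypothetical parameterized variant of the tests; main()'s argument order and
# the chosen shapes are assumptions, and CUDA gating is omitted for brevity.
import pytest

import example_mha_sink_fwd_bhsd


@pytest.mark.parametrize("batch,heads,seq_q,seq_kv,dim", [
    (1, 8, 256, 256, 64),
    (2, 16, 512, 1024, 128),
])
@pytest.mark.parametrize("window_size", [None, 128])
def test_mha_sink_fwd(batch, heads, seq_q, seq_kv, dim, window_size):
    # Each case runs the example end to end, which includes its own correctness checks.
    example_mha_sink_fwd_bhsd.main(batch, heads, seq_q, seq_kv, dim, window_size)
```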
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (3)
243-260: Remove unused parameters from triton_kernel

Similar to the MHA implementation, parameters `Z`, `N_Q_CTX`, and `N_KV_CTX` are unused.

Remove the unused parameters and update the call site accordingly (same pattern as suggested for the MHA implementation).
418-425: Inconsistent validation approach between implementations

The GQA implementation uses the same suboptimal validation pattern for Triton as the MHA implementation.

Apply the same improved validation approach using `torch.testing.assert_close` with exception handling as suggested for the MHA implementation.
358-359: Consider more descriptive parameter names

The `gen_inputs` function parameter names use single letters (B, H, Sq, Skv, D), which reduces readability.

Consider using full names:

```diff
-def gen_inputs(B, H, Sq, Skv, D,
-               groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+def gen_inputs(batch_size, num_heads, seq_len_q, seq_len_kv, head_dim,
+               groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (1 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- examples/attention_sink/test_example_attention_sink.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
- tilelang/autotuner/tuner.py (1)
  - autotune (692-785)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
  - get_configs (16-18)
  - ref_program (196-239)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
  - get_configs (16-18)
  - ref_program (189-233)
- tilelang/jit/__init__.py (1)
  - jit (237-310)
- tilelang/language/pipeline.py (1)
  - Pipelined (9-46)
examples/attention_sink/test_example_attention_sink.py (4)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
  - main (133-189)
  - main (367-435)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
  - main (121-171)
  - main (232-284)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
  - main (126-182)
  - main (357-418)
- tilelang/testing/__init__.py (1)
  - requires_cuda_compute_version_ge (95-96)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (7)
- tilelang/autotuner/tuner.py (1)
  - autotune (692-785)
- tilelang/jit/__init__.py (1)
  - jit (237-310)
- tilelang/language/copy.py (1)
  - copy (84-152)
- tilelang/language/tir/op.py (1)
  - if_then_else (2824-2854)
- tilelang/language/reduce.py (2)
  - reduce_max (50-68)
  - reduce_sum (87-109)
- tilelang/language/allocate.py (3)
  - alloc_shared (21-36)
  - alloc_fragment (53-64)
  - alloc_local (39-50)
- tilelang/language/pipeline.py (1)
  - Pipelined (9-46)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (7)
- tilelang/autotuner/tuner.py (1)
  - autotune (692-785)
- tilelang/jit/__init__.py (1)
  - jit (237-310)
- tilelang/language/copy.py (1)
  - copy (84-152)
- tilelang/language/tir/op.py (1)
  - if_then_else (2824-2854)
- tilelang/language/reduce.py (2)
  - reduce_max (50-68)
  - reduce_sum (87-109)
- tilelang/language/allocate.py (3)
  - alloc_shared (21-36)
  - alloc_fragment (53-64)
  - alloc_local (39-50)
- tilelang/language/pipeline.py (1)
  - Pipelined (9-46)
🪛 Ruff (0.13.1)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
244-244: Unused function argument: Z
(ARG001)
246-246: Unused function argument: N_Q_CTX
(ARG001)
247-247: Unused function argument: N_KV_CTX
(ARG001)
261-261: Unused noqa directive (non-enabled: SIM108)
Remove unused noqa directive
(RUF100)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
250-250: Unused function argument: Z
(ARG001)
252-252: Unused function argument: N_Q_CTX
(ARG001)
253-253: Unused function argument: N_KV_CTX
(ARG001)
268-268: Unused noqa directive (non-enabled: SIM108)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-nvidia
- GitHub Check: build-test-amd
🔇 Additional comments (1)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)
146-146: Incorrect allocation shape for sinks fragment buffer

The `sinks` fragment buffer is allocated with shape `[heads]`, but only `block_M` values are accessed in Line 148. This is a mismatch that could lead to out-of-bounds access.

Apply this fix:

```diff
-sinks = T.alloc_fragment([heads], dtype)
+sinks = T.alloc_fragment([block_M], dtype)
```

Likely an incorrect or invalid review comment.
…ned.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
…ned.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Actionable comments posted: 1
🧹 Nitpick comments (4)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
250-253: Clean up unused Triton kernel parameters

The Triton kernel has unused parameters `Z`, `N_Q_CTX`, and `N_KV_CTX` that should be removed for clarity.

```diff
-def triton_kernel(
-    Q,
-    K,
-    V,
-    Sinks,
-    sm_scale,
-    Out,
-    Z,
-    H,
-    N_Q_CTX,
-    N_KV_CTX,
-    HEAD_DIM: tl.constexpr,
+def triton_kernel(
+    Q,
+    K,
+    V,
+    Sinks,
+    sm_scale,
+    Out,
+    H,
+    HEAD_DIM: tl.constexpr,
```

And update the call site:

```diff
 triton_kernel[grid](
     TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
     TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
     TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
     Sinks,
     1.0 / head_dim**0.5,
     TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
-    bs,
     n_heads,
-    N_Q_CTX=seq_q,
-    N_KV_CTX=seq_kv,
     HEAD_DIM=head_dim,
```
268-268: Remove unnecessary noqa comment

The `noqa: SIM108` directive is not needed here as SIM108 is not enabled in your linting configuration.

```diff
-if Sinks is not None:  # noqa: SIM108
+if Sinks is not None:
```

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
195-196: Typo in inline comment

There's a formatting issue in the comment on lines 195-196.

```diff
-query = query.transpose(1, 2).contiguous().unsqueeze(
-    3)  # align with the original function'sinterface
+query = query.transpose(1, 2).contiguous().unsqueeze(
+    3)  # align with the original function's interface
```
380-411: Consider extracting common benchmark code

Both MHA and GQA example files contain nearly identical benchmarking logic. Consider extracting this into a shared utility function to reduce duplication and improve maintainability.

Would you like me to help create a shared benchmarking utility that could be used by both example files? This would reduce code duplication and make future maintenance easier. One possible shape is sketched below.
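A minimal sketch of such a helper, assuming Triton's `do_bench` timing and an attention-style FLOPs estimate (both assumptions for illustration; the module path `benchmark_utils.py` and function name are hypothetical):

```python
# Hypothetical shared helper (e.g. examples/attention_sink/benchmark_utils.py);
# do_bench-based timing and the FLOPs formula are illustrative assumptions.
from triton.testing import do_bench


def benchmark_attention(fn, batch, heads, seq_q, seq_kv, dim, label, window_size=None):
    # Effective keys per query under a sliding window, ignoring boundary effects.
    kv_per_query = seq_kv if window_size is None else min(window_size, seq_kv)
    total_flops = 4 * batch * heads * seq_q * kv_per_query * dim  # QK^T + PV matmuls
    ms = do_bench(fn, warmup=25, rep=100)
    print(f"{label}: {ms:.3f} ms, {total_flops / (ms * 1e-3) / 1e12:.2f} TFLOPS")
    return ms
```

Both example files' `main()` functions could then call `benchmark_attention(lambda: kernel(Q, K, V, sinks), ...)` for the TileLang and Triton paths.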
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
- tilelang/autotuner/tuner.py (1)
  - autotune (692-785)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (5)
  - flashattn (30-191)
  - ref_program (196-239)
  - triton_kernel (243-326)
  - triton_program (329-355)
  - gen_inputs (358-364)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (3)
- tilelang/autotuner/tuner.py (1)
  - autotune (692-785)
- tilelang/jit/__init__.py (1)
  - jit (237-310)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-89)
🪛 Ruff (0.13.1)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
244-244: Unused function argument: Z
(ARG001)
246-246: Unused function argument: N_Q_CTX
(ARG001)
247-247: Unused function argument: N_KV_CTX
(ARG001)
261-261: Unused noqa directive (non-enabled: SIM108)
Remove unused noqa directive
(RUF100)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
250-250: Unused function argument: Z
(ARG001)
252-252: Unused function argument: N_Q_CTX
(ARG001)
253-253: Unused function argument: N_KV_CTX
(ARG001)
268-268: Unused noqa directive (non-enabled: SIM108)
Remove unused noqa directive
(RUF100)
🔇 Additional comments (4)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)
153-154: Fixed sinks fragment allocation shape.

Great! The previous issue with incorrect allocation shape for the `sinks` fragment buffer has been resolved. The buffer is now correctly allocated with shape `[block_M]` to match its usage pattern.

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (3)
146-146: Fixed sinks fragment allocation shape.

Good! The `sinks` fragment allocation issue has been properly addressed with shape `[block_M]` matching its usage pattern.
244-247: Clean up unused Triton kernel parameters

Similar to the GQA file, the Triton kernel has unused parameters that should be removed.
261-261: Remove unnecessary noqa comment

The `noqa: SIM108` directive should be removed.
```python
def gen_inputs(B, H, Sq, Skv, D,
               groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
```
Missing groups argument in the gen_inputs call

The gen_inputs function accepts a groups parameter, but it is not passed when called from main on line 412. You'll need to pass the groups argument there.
Apply this fix:

```diff
-Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim)
+Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups)
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
               groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups)
```
🤖 Prompt for AI Agents
In examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py around
line 359, gen_inputs declares a groups parameter but the call in main at line
412 does not pass it; update the call at line 412 to supply the groups argument
(e.g., pass the local variable named groups or the appropriate value/constant
you intended) so gen_inputs receives and can use groups.
TODO
- seqlen_q < seqlen_kv

Summary by CodeRabbit
New Features
Tests
Documentation