
Conversation

@Rachmanino Rachmanino commented Sep 20, 2025

TODO

  • MHA + attention sink
  • Sliding window attention + attention sink
  • Implement WGMMA-pipelined version for optimal performance on Hopper
  • Add a benchmark to compare with the official Triton implementation
  • Optimize sliding window attention performance to surpass Triton
  • Fix cases where seqlen_q < seqlen_kv
  • Support GQA cases

Summary by CodeRabbit

  • New Features

    • Added multiple attention-sink forward-pass examples (MHA and GQA) with TileLang and Triton implementations, optional sliding-window support, autotuning, benchmarking, and CLI-driven orchestration.
  • Tests

    • Added CUDA-gated tests covering full and sliding-window modes for MHA and GQA, including high-performance pipelined variants; tests validate correctness across implementations.
  • Documentation

    • Runnable examples include input generators, reference implementations, performance reporting, and tuning/configuration helpers.

- Introduced a new example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Added a reference attention function to validate the implementation against PyTorch.
- Included argument parsing for command-line execution of the example.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Rachmanino and others added 4 commits September 22, 2025 10:18
- Removed the old example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Introduced a new example script that modifies the attention mechanism to enhance performance and maintainability.
- Updated argument parsing and reference functions to align with the new implementation.
- Added a `window_size` parameter to the `flashattn` function to enable sliding window attention.
- Implemented assertions to ensure `window_size` is compatible with `block_N`.
- Updated the main function to include a `tune` option for performance tuning.
- Introduced a new test file to validate both full attention and sliding window scenarios.
- Adjusted FLOPS calculation to account for the sliding window configuration.

coderabbitai bot commented Sep 22, 2025

Walkthrough

Adds three new attention-sink example programs (MHA and GQA variants, including WGMMA-pipelined kernels) with TileLang and Triton implementations, autotuning, reference correctness checks, input generators, benchmarking, and CLI entrypoints. Adds CUDA-gated tests exercising full and sliding-window modes for these examples.

Changes

  • MHA attention sink examples (examples/attention_sink/example_mha_sink_fwd_bhsd.py, examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py): New FlashAttention-style forward implementations with attention-sink integration: TileLang-jitted/autotuned kernels (config grids), internal MMA/Softmax/Rescale macros, optional sliding-window support, Triton reference kernel (pipelined file), PyTorch reference, input generators, benchmarking/validation harness, and CLI entrypoints.
  • GQA attention sink example, WGMMA pipelined (examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py): New GQA-oriented attention forward example: Hopper/WGMMA-optimized TileLang kernel, Triton wrapper/kernel, ref_program, autotune/get_configs, gen_inputs, benchmarking and correctness checks, and CLI entrypoint.
  • Tests for attention sink examples (examples/attention_sink/test_example_attention_sink.py): New CUDA-dependent tests that invoke the examples' main() functions for full and sliding-window modes; WGMMA-pipelined tests are gated on CUDA compute capability ≥ 9.0; includes a script entrypoint for running the tests standalone.
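
For context on how the tests are gated, a minimal sketch is shown below. It is not copied from the PR: the decorator arguments for tilelang.testing.requires_cuda_compute_version_ge and the argument-free main() call are assumptions.

import tilelang.testing

import example_mha_sink_fwd_bhsd_wgmma_pipelined


# Hypothetical sketch: gate a WGMMA-pipelined test on Hopper-class GPUs
# (compute capability >= 9.0). Decorator arguments are assumed.
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_mha_sink_wgmma_pipelined_full():
    # Run the example end to end; correctness checks live inside main().
    example_mha_sink_fwd_bhsd_wgmma_pipelined.main()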

Sequence Diagram(s)

sequenceDiagram
    autonumber
    actor User
    participant CLI as CLI/main()
    participant Gen as gen_inputs()
    participant TL as TileLang flashattn()
    participant Ref as ref_program()
    participant Tri as triton_program()
    participant Bench as do_bench

    User->>CLI: parse args (B,H,Sq,Skv,D,groups,window,tune)
    alt tune mode
        CLI->>TL: run autotuner & select config
        TL-->>CLI: tuned kernel + latency
        CLI->>Bench: benchmark tuned TL
        Bench-->>CLI: results
    else execute
        CLI->>Gen: allocate Q,K,V,Sinks on CUDA
        Gen-->>CLI: tensors
        CLI->>TL: run TileLang kernel (with sinks/window)
        TL-->>CLI: Out_TL
        CLI->>Ref: compute ref_program
        Ref-->>CLI: Out_Ref
        CLI->>Tri: run Triton kernel
        Tri-->>CLI: Out_Tri
        CLI->>CLI: compare Out_TL vs Out_Ref and Out_Tri
        CLI->>Bench: benchmark TL and Triton
        Bench-->>CLI: latencies/throughput
    end
    Note over TL,Tri: Softmax/normalization incorporate per-head sinks and optional sliding-window masking
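To make the note above concrete, the sketch below shows one common way a per-head sink enters the softmax normalization, with an optional sliding-window mask. It illustrates the general technique under assumed tensor shapes and names; it is not the PR's ref_program.

import torch


def sink_attention_ref(q, k, v, sinks, window_size=None):
    # Assumed shapes: q, k, v are [B, H, S, D]; sinks is [H], one logit per head.
    B, H, Sq, D = q.shape
    Skv = k.shape[2]
    scale = D ** -0.5
    logits = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale  # [B, H, Sq, Skv]

    # Causal mask, offset so the last query row sees every key when Sq < Skv.
    q_idx = torch.arange(Sq, device=q.device)[:, None] + (Skv - Sq)
    k_idx = torch.arange(Skv, device=q.device)[None, :]
    visible = k_idx <= q_idx
    if window_size is not None:
        # Sliding window: only the most recent window_size keys stay visible.
        visible &= k_idx > q_idx - window_size
    logits = logits.masked_fill(~visible, float("-inf"))

    # The per-head sink logit joins the normalization but contributes no value,
    # so the attention weights over real keys can sum to less than one.
    sink_col = sinks.view(1, H, 1, 1).expand(B, H, Sq, 1).to(logits.dtype)
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1).float(), dim=-1)
    probs = probs[..., :Skv].to(v.dtype)
    return torch.einsum("bhqk,bhkd->bhqd", probs, v)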

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

I hopped through heads where queries sing,
Sliding windows, sinks in the ring.
Triton streams and TileLang rhyme,
Autotune tunes to kernel time.
Benchmarks hum — carrots at the end. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title clearly and concisely describes the primary change, adding examples to support an efficient attention-sink forward process, which matches the changeset that introduces multiple example implementations, benchmarks, and tests for attention sinks. It is specific enough for a reviewer to understand the main intent without listing files or extraneous detail, and it avoids vague terms and noisy formatting.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@Rachmanino Rachmanino changed the title from "[Example] Add a new example to support attention sink forward process" to "[Example] Add examples to support efficient attention sink forward process" Sep 22, 2025
@Rachmanino Rachmanino marked this pull request as ready for review September 23, 2025 11:30

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (10)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)

13-14: Consider exposing more block sizes in autotuning configs

The current configurations only test block_M=[128], block_N=[128], limiting the autotuning exploration space. For better performance across different problem sizes, consider including more block size variations.

Consider expanding the configuration space:

-iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
+iter_params = dict(block_M=[64, 128], block_N=[64, 128], num_stages=[0, 1, 2], threads=[128, 256])
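
For reference, a grid like iter_params is typically expanded into a list of keyword-argument configurations with a Cartesian product; the sketch below assumes that is roughly what get_configs does, and only illustrates how the search space grows.

import itertools

# Hedged sketch: expand an iter_params-style grid into per-config kwargs.
# The PR's actual get_configs() may differ.
iter_params = dict(
    block_M=[64, 128],
    block_N=[64, 128],
    num_stages=[0, 1, 2],
    threads=[128, 256],
)
configs = [
    dict(zip(iter_params.keys(), values))
    for values in itertools.product(*iter_params.values())
]
print(len(configs))  # 2 * 2 * 3 * 2 = 24 candidate configurations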

183-184: Remove unnecessary unsqueeze in ref_program

The unsqueeze operation on lines 183-184 adds an unnecessary dimension that the reference implementation then immediately has to handle. This can be simplified.

Consider removing the unsqueeze:

-query = query.transpose(1, 2).contiguous().unsqueeze(
-    3)  # align with the original function's interface
+query = query.transpose(1, 2).contiguous()
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (3)

195-196: Typo in comment: missing space

There's a missing space in the comment on Line 196.

Fix the typo:

-    3)  # align with the original function'sinterface
+    3)  # align with the original function's interface

236-253: Remove unused parameters from triton_kernel

The static analysis correctly identifies that parameters Z, N_Q_CTX, and N_KV_CTX are never used in the Triton kernel. These should be removed to clean up the interface.

Remove the unused parameters:

@triton.jit
def triton_kernel(
    Q,
    K,
    V,
    Sinks,
    sm_scale,
    Out,
-   Z,
    H,
-   N_Q_CTX,
-   N_KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BANDWIDTH: tl.constexpr,
    start_q: tl.constexpr,
):

And update the call site:

triton_kernel[grid](
    TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
    TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
    TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
    Sinks,
    1.0 / head_dim**0.5,
    TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
-   bs,
    n_heads,
-   N_Q_CTX=seq_q,
-   N_KV_CTX=seq_kv,
    HEAD_DIM=head_dim,
    BANDWIDTH=window_size,
    BLOCK_M=BLOCK_M,
    BLOCK_N=BLOCK_N,
    start_q=seq_kv - seq_q)

404-411: Consider more robust validation for Triton implementation

The current validation of the Triton implementation uses torch.allclose with a boolean check and prints pass/fail. Consider using torch.testing.assert_close for consistency with the TileLang validation.

Improve the validation approach:

-if torch.allclose(
-        triton_program(Q, K, V, sinks, window_size),
-        ref_program(Q, K, V, sinks, window_size),
-        rtol=1e-2,
-        atol=1e-2):
-    print("Checks for triton passed.✅")
-else:
-    print("Checks for triton failed.❌")
+try:
+    torch.testing.assert_close(
+        triton_program(Q, K, V, sinks, window_size),
+        ref_program(Q, K, V, sinks, window_size),
+        rtol=1e-2,
+        atol=1e-2)
+    print("Checks for triton passed.✅")
+except AssertionError as e:
+    print(f"Checks for triton failed.❌\n{e}")
examples/attention_sink/test_example_attention_sink.py (2)

3-5: Use absolute imports for better maintainability

The test file imports its sibling example modules as if they were top-level modules. Consider using absolute imports or explicit relative imports for clearer module dependencies.

Use explicit relative imports:

-import example_mha_sink_fwd_bhsd
-import example_mha_sink_fwd_bhsd_wgmma_pipelined
-import example_gqa_sink_fwd_bhsd_wgmma_pipelined
+from . import example_mha_sink_fwd_bhsd
+from . import example_mha_sink_fwd_bhsd_wgmma_pipelined  
+from . import example_gqa_sink_fwd_bhsd_wgmma_pipelined

9-40: Consider parameterized tests for better coverage

The tests use hardcoded values (e.g., window_size=128). Consider using parameterized tests to cover more configurations and edge cases.

Would you like me to help create parameterized test cases that cover various batch sizes, sequence lengths, and window sizes to ensure robustness across different configurations?
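
As a starting point, parameterized coverage could look like the sketch below; the keyword names passed to main() are assumptions about the example's interface, not taken from the PR.

import pytest
import torch

import example_mha_sink_fwd_bhsd


# Hedged sketch of parameterized coverage. The keyword arguments passed to
# main() are assumptions and would need to match the example's signature.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("seq_q,seq_kv", [(1024, 1024), (512, 2048)])
@pytest.mark.parametrize("window_size", [None, 128, 512])
def test_mha_sink_fwd_configs(seq_q, seq_kv, window_size):
    example_mha_sink_fwd_bhsd.main(
        batch=1, heads=8, seq_q=seq_q, seq_kv=seq_kv, dim=128,
        window_size=window_size)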

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (3)

243-260: Remove unused parameters from triton_kernel

Similar to the MHA implementation, parameters Z, N_Q_CTX, and N_KV_CTX are unused.

Remove the unused parameters and update the call site accordingly (same pattern as suggested for the MHA implementation).


418-425: Inconsistent validation approach between implementations

The GQA implementation uses the same suboptimal validation pattern for Triton as the MHA implementation.

Apply the same improved validation approach using torch.testing.assert_close with exception handling as suggested for the MHA implementation.


358-359: Consider more descriptive parameter names

The gen_inputs function parameter names use single letters (B, H, Sq, Skv, D), which reduces readability.

Consider using full names:

-def gen_inputs(B, H, Sq, Skv, D,
-               groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+def gen_inputs(batch_size, num_heads, seq_len_q, seq_len_kv, head_dim,
+               groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a3497eb and eef32b3.

📒 Files selected for processing (4)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py (1 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/attention_sink/test_example_attention_sink.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
  • get_configs (16-18)
  • ref_program (196-239)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
  • get_configs (16-18)
  • ref_program (189-233)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
examples/attention_sink/test_example_attention_sink.py (4)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
  • main (133-189)
  • main (367-435)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
  • main (121-171)
  • main (232-284)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
  • main (126-182)
  • main (357-418)
tilelang/testing/__init__.py (1)
  • requires_cuda_compute_version_ge (95-96)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (7)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/tir/op.py (1)
  • if_then_else (2824-2854)
tilelang/language/reduce.py (2)
  • reduce_max (50-68)
  • reduce_sum (87-109)
tilelang/language/allocate.py (3)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
  • alloc_local (39-50)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (7)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/tir/op.py (1)
  • if_then_else (2824-2854)
tilelang/language/reduce.py (2)
  • reduce_max (50-68)
  • reduce_sum (87-109)
tilelang/language/allocate.py (3)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
  • alloc_local (39-50)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
🪛 Ruff (0.13.1)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

244-244: Unused function argument: Z

(ARG001)


246-246: Unused function argument: N_Q_CTX

(ARG001)


247-247: Unused function argument: N_KV_CTX

(ARG001)


261-261: Unused noqa directive (non-enabled: SIM108)

Remove unused noqa directive

(RUF100)

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

250-250: Unused function argument: Z

(ARG001)


252-252: Unused function argument: N_Q_CTX

(ARG001)


253-253: Unused function argument: N_KV_CTX

(ARG001)


268-268: Unused noqa directive (non-enabled: SIM108)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-nvidia
  • GitHub Check: build-test-amd
🔇 Additional comments (1)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)

146-146: Incorrect allocation shape for sinks fragment buffer

The sinks fragment buffer is allocated with shape [heads], but only block_M values are accessed on line 148. This mismatch could lead to out-of-bounds access.

Apply this fix:

-sinks = T.alloc_fragment([heads], dtype)
+sinks = T.alloc_fragment([block_M], dtype)

Likely an incorrect or invalid review comment.

Rachmanino and others added 2 commits September 23, 2025 20:11
…ned.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
…ned.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)

250-253: Clean up unused Triton kernel parameters

The Triton kernel has unused parameters Z, N_Q_CTX, and N_KV_CTX that should be removed for clarity.

-def triton_kernel(
-    Q,
-    K,
-    V,
-    Sinks,
-    sm_scale,
-    Out,
-    Z,
-    H,
-    N_Q_CTX,
-    N_KV_CTX,
-    HEAD_DIM: tl.constexpr,
+def triton_kernel(
+    Q,
+    K,
+    V,
+    Sinks,
+    sm_scale,
+    Out,
+    H,
+    HEAD_DIM: tl.constexpr,

And update the call site:

triton_kernel[grid](
    TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
    TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
    TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
    Sinks,
    1.0 / head_dim**0.5,
    TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
-    bs,
    n_heads,
-    N_Q_CTX=seq_q,
-    N_KV_CTX=seq_kv,
    HEAD_DIM=head_dim,

268-268: Remove unnecessary noqa comment

The noqa: SIM108 directive is not needed here as SIM108 is not enabled in your linting configuration.

-if Sinks is not None:  # noqa: SIM108
+if Sinks is not None:
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)

195-196: Typo in inline comment

There's a missing space in the inline comment on lines 195-196.

-query = query.transpose(1, 2).contiguous().unsqueeze(
-    3)  # align with the original function'sinterface
+query = query.transpose(1, 2).contiguous().unsqueeze(
+    3)  # align with the original function's interface

380-411: Consider extracting common benchmark code

Both MHA and GQA example files contain nearly identical benchmarking logic. Consider extracting this into a shared utility function to reduce duplication and improve maintainability.

Would you like me to help create a shared benchmarking utility that could be used by both example files? This would reduce code duplication and make future maintenance easier.
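
A shared helper along these lines could work, assuming triton.testing.do_bench for timing; the function name and FLOP formula below are illustrative assumptions, not the PR's code.

from triton.testing import do_bench


def bench_attention(fn, B, H, Sq, Skv, D, window_size=None, label=""):
    # Hedged sketch of a shared helper: `fn` is a zero-argument callable that
    # runs one forward pass. The FLOP accounting (two matmuls per attended
    # key, truncated by the window) is an assumption and would need to match
    # the examples' own formula.
    ms = do_bench(fn)
    keys_per_query = Skv if window_size is None else min(window_size, Skv)
    flops = 2 * 2 * B * H * Sq * keys_per_query * D
    tflops = flops / (ms * 1e-3) / 1e12
    print(f"{label}: {ms:.3f} ms, {tflops:.2f} TFLOPS")
    return ms, tflops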

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eef32b3 and 833e603.

📒 Files selected for processing (2)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (5)
  • flashattn (30-191)
  • ref_program (196-239)
  • triton_kernel (243-326)
  • triton_program (329-355)
  • gen_inputs (358-364)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (3)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-89)
🪛 Ruff (0.13.1)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

244-244: Unused function argument: Z

(ARG001)


246-246: Unused function argument: N_Q_CTX

(ARG001)


247-247: Unused function argument: N_KV_CTX

(ARG001)


261-261: Unused noqa directive (non-enabled: SIM108)

Remove unused noqa directive

(RUF100)

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

250-250: Unused function argument: Z

(ARG001)


252-252: Unused function argument: N_Q_CTX

(ARG001)


253-253: Unused function argument: N_KV_CTX

(ARG001)


268-268: Unused noqa directive (non-enabled: SIM108)

Remove unused noqa directive

(RUF100)

🔇 Additional comments (4)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)

153-154: Fixed sinks fragment allocation shape.

Great! The previous issue with incorrect allocation shape for sinks fragment buffer has been resolved. The buffer is now correctly allocated with shape [block_M] to match its usage pattern.

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (3)

146-146: Fixed sinks fragment allocation shape.

Good! The sinks fragment allocation issue has been properly addressed with shape [block_M] matching its usage pattern.


244-247: Clean up unused Triton kernel parameters

Similar to the GQA file, the Triton kernel has unused parameters that should be removed.


261-261: Remove unnecessary noqa comment

The noqa: SIM108 directive should be removed.



def gen_inputs(B, H, Sq, Skv, D,
groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

⚠️ Potential issue

Missing groups argument in the gen_inputs call

The gen_inputs function accepts a groups parameter, but the call from main on line 412 does not pass it. You'll need to pass the groups argument there.

Apply this fix:

-Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim)
+Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups)
🤖 Prompt for AI Agents
In examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py around
line 359, gen_inputs declares a groups parameter but the call in main at line
412 does not pass it; update the call at line 412 to supply the groups argument
(e.g., pass the local variable named groups or the appropriate value/constant
you intended) so gen_inputs receives and can use groups.

@LeiWang1999 LeiWang1999 merged commit d9a171c into tile-ai:main Sep 23, 2025
4 of 6 checks passed
@Rachmanino Rachmanino deleted the attention-sink branch September 23, 2025 16:10