[Example] Add support for bfloat16 and user-defined sm_scale in attention sink examples (#924)
        
Conversation
Walkthrough

Adds a runtime-selectable dtype (float16 | bfloat16) and an optional sm_scale to multiple attention-sink and MHA examples, and propagates the dtype through kernel builds, input generation, reference programs, tests, and the CLI (--dtype). It also removes several backward-path seq_len bounds checks around atomic_add updates in multiple backward examples.
Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant CLI as CLI (--dtype/--sm_scale)
  participant Main as main()
  participant Gen as gen_inputs(torch_dtype)
  participant Builder as kernel builder (flashattn(dtype,sm_scale))
  participant Kern as TileLang Kernel
  participant Ref as ref_program(torch_dtype)
  CLI->>Main: parse args (--dtype, --sm_scale)
  Main->>Gen: create Q,K,V,sinks (torch_dtype)
  Main->>Builder: build kernel (dtype, sm_scale)
  Main->>Kern: invoke kernel(Q,K,V,sinks)
  Kern-->>Main: O (kernel dtype)
  Main->>Ref: run ref_program(Q,K,V,sinks, dtype=torch_dtype)
  Ref-->>Main: O_ref (cast to torch_dtype)
  Main-->>CLI: report metrics and checks
```

```mermaid
sequenceDiagram
  autonumber
  participant Main as main(dtype)
  participant Prep as flashattn_bwd_preprocess(dtype)
  participant Bwd as flashattn_bwd(dtype, sm_scale)
  participant Atomic as atomic_add loop
  participant Post as flashattn_bwd_postprocess(dtype)
  Main->>Prep: prepare buffers/state (dtype-aware)
  Main->>Bwd: run backward (accum_dtype chosen by dtype)
  Bwd->>Atomic: loop over k,i,j — atomic_add(dQ...)
  Note right of Atomic #ffefc6: guard removed — atomic_add now unconditional
  Atomic-->>Post: partial grads
  Post-->>Main: finalized dQ,dK,dV (cast to dtype)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (2)

🧰 Additional context used
🧬 Code graph analysis (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)

🪛 Ruff (0.13.3)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
19-19: Avoid specifying long messages outside the exception class (TRY003)
examples/attention_sink/example_mha_sink_bwd_bhsd.py
19-19: Avoid specifying long messages outside the exception class (TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
 🔇 Additional comments (11)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
…pport in both GQA and MHA backward implementations.
Actionable comments posted: 2
🧹 Nitpick comments (1)
examples/attention_sink/README.md (1)
3-3: Use descriptive link text. Markdownlint (MD059) warns about generic “here”; please name the target (e.g., “optimized Triton implementation”).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/attention_sink/README.md (2 hunks)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (16 hunks)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (9 hunks)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (15 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (7 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (3)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (6)
flashattn(33-205)
gen_inputs(373-385)
ref_program(210-254)
triton_program(344-370)
main(140-203)
main(388-464)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
flashattn(25-186)
gen_inputs(238-249)
ref_program(190-235)
main(127-184)
main(252-311)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (3)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
gen_inputs(238-249)
flashattn(25-186)
ref_program(190-235)
main(127-184)
main(252-311)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (5)
gen_inputs(364-375)
flashattn(29-198)
ref_program(203-248)
main(133-196)
main(378-445)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (3)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (5)
flashattn(33-205)
gen_inputs(373-385)
ref_program(210-254)
main(140-203)
main(388-464)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (5)
flashattn(29-198)
gen_inputs(364-375)
ref_program(203-248)
main(133-196)
main(378-445)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (3)
tilelang/jit/__init__.py (1)
jit(240-313)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (9)
flashattn_fwd(27-138)
flashattn_bwd_preprocess(147-172)
flashattn_bwd_postprocess(187-204)
flashattn_bwd(212-322)
get_bwd_configs(10-18)
flashattn_bwd_dsink(326-351)
main(451-520)
ref_program(404-448)
backward(375-396)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
tilelang/jit/__init__.py (1)
jit(240-313)
🪛 markdownlint-cli2 (0.18.1)
examples/attention_sink/README.md
3-3: Link text should be descriptive
(MD059, descriptive-link-text)
🪛 Ruff (0.13.3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py
18-18: Avoid specifying long messages outside the exception class
(TRY003)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
18-18: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: build-test-metal
- GitHub Check: build-test-amd
- GitHub Check: bot-task
```diff
         dk = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
         dv = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
         kernel(q, k, v, do, lse, delta, dq, dk, dv)
         dq = kernel_post(dq)

-        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX)
+        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
         dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
         return dq, dk, dv, dsinks, None, None
```
Ensure gradients match the input dtype before returning.
dk and dv remain torch.float32 accumulators when returned from backward, but the corresponding inputs (K, V) are created in float16/bfloat16. PyTorch expects gradients in the same dtype; returning float32 here can break gradient checks or trigger type errors. Convert dk/dv back to the original dtype (mirroring the dQ postprocess) before returning from _attention.backward.
🤖 Prompt for AI Agents
In examples/attention_sink/example_gqa_sink_bwd_bhsd.py around lines 389 to 396,
dk and dv are created as float32 but returned to match K/V which may be
float16/bfloat16; convert dk and dv back to the attention input dtype before
returning (mirror dq postprocessing) by casting them to the same dtype variable
used for inputs (e.g., dk = dk.to(dtype) and dv = dv.to(dtype)) so gradients
have the correct dtype.
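A minimal sketch of the suggested cast (hedged; it assumes `k` and `v` are the saved input tensors available inside `_attention.backward`, and that the return signature matches the snippet above):

```python
# Fragment of the backward body: cast the float32 accumulators back to the
# dtype of the corresponding inputs before handing them to autograd.
dq = kernel_post(dq)      # dQ is already converted by the existing postprocess kernel
dk = dk.to(k.dtype)       # dk was accumulated in float32; match K's dtype
dv = dv.to(v.dtype)       # dv was accumulated in float32; match V's dtype
return dq, dk, dv, dsinks, None, None
```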
```python
        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
        return dq, dk, dv, dsinks, None
```
Cast dsinks back to the input dtype before returning.
flashattn_bwd_dsink now produces float32 (accumulation dtype), but autograd expects the gradient for sinks to match sinks.dtype (float16/bfloat16). Return path currently hands PyTorch a float32 tensor, causing silent dtype promotion and breaking mixed-precision assumptions. Convert before returning:
```diff
-        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
-        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
+        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
         kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
-        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)
         return dq, dk, dv, dsinks, None
```
🤖 Prompt for AI Agents
In examples/attention_sink/example_mha_sink_bwd_bhsd.py around lines 392 to 394,
dsinks is produced in float32 (accumulation dtype) but returned directly,
causing PyTorch to receive a float32 gradient while sinks has float16/bfloat16;
convert dsinks back to the input gradient dtype before returning (e.g., cast to
sinks.dtype or use .type_as(sinks)/.to(sinks.dtype)) so the returned gradient
matches sinks' original dtype and preserves mixed-precision behavior.
Actionable comments posted: 0
♻️ Duplicate comments (2)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
389-391: Cast `dsinks` back to the input dtype before returning.

This issue was previously flagged and remains unresolved. The `flashattn_bwd_dsink` kernel now produces `float32` (accumulation dtype) at line 335, but autograd expects the gradient for `sinks` to match `sinks.dtype` (float16/bfloat16). The current implementation returns a `float32` tensor directly, causing silent dtype promotion and breaking mixed-precision assumptions.

Apply this diff:

```diff
         kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
-        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(q.dtype)
```

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
207-322: Critical: Add postprocessing to convert `dK` and `dV` from `accum_dtype` to `dtype`.

The kernel signature declares `dK` and `dV` as `accum_dtype` (float32) for better numerical stability during atomic accumulation (lines 243-244). However, the backward method (line 396) returns these float32 gradients directly without converting them to match the input dtype of `K` and `V` (which are float16/bfloat16). PyTorch's autograd expects gradients to have the same dtype as their corresponding inputs, so this mismatch will cause type errors or break gradient checks.

Similar to how `dQ` is postprocessed (line 392: `dq = kernel_post(dq)`), you need to convert `dk` and `dv` back to the original dtype before returning.

Apply this diff to add the dtype conversion before returning:

```diff
         kernel(q, k, v, do, lse, delta, dq, dk, dv)
         dq = kernel_post(dq)
 
+        # Convert dk and dv from float32 accumulation back to input dtype
+        torch_dtype = q.dtype
+        dk = dk.to(torch_dtype)
+        dv = dv.to(torch_dtype)
+
         kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
         dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
         return dq, dk, dv, dsinks, None, None
```
🧹 Nitpick comments (1)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
325-351: Consider using `accum_dtype` for `dsinks` to match MHA and improve precision.

The `dsinks` tensor and `dsink_fragment` are allocated with `dtype` (lines 335, 341), but the computation (lines 347-348) involves `exp2` and multiplication with `accum_dtype` values (`delta_fragment`). The result is then summed in the backward method (line 395: `.sum(0).sum(1)`), which could benefit from higher precision accumulation.

For consistency with the MHA version (see `example_mha_sink_bwd_bhsd.py`), consider using `accum_dtype` for both `dsinks` and `dsink_fragment`. This would improve numerical stability, especially when accumulated across batches and sequences.

Apply this diff:

```diff
     @T.prim_func
     def flash_bwd_dsink(
             Sinks: T.Tensor([heads], dtype),  # type: ignore
             Delta: T.Tensor(shape, accum_dtype),  # type: ignore
             lse: T.Tensor(shape, accum_dtype),  # type: ignore
-            dsinks: T.Tensor(shape, dtype),  # type: ignore
+            dsinks: T.Tensor(shape, accum_dtype),  # type: ignore
     ):
         with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz):
             sink = T.alloc_local([1], dtype)
             lse_fragment = T.alloc_fragment([block], accum_dtype)
             delta_fragment = T.alloc_fragment([block], accum_dtype)
-            dsink_fragment = T.alloc_fragment([block], dtype)
+            dsink_fragment = T.alloc_fragment([block], accum_dtype)
```

Then convert the result to dtype after the sum operation in the backward method if needed.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (16 hunks)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (15 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (3)
tilelang/jit/__init__.py (1)
jit(240-313)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (8)
flashattn_fwd(27-138)
flashattn_bwd_preprocess(147-172)
flashattn_bwd_postprocess(187-204)
flashattn_bwd(212-322)
get_bwd_configs(10-18)
flashattn_bwd_dsink(326-351)
main(451-520)
ref_program(404-448)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)
tilelang/jit/__init__.py (1)
jit(240-313)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (10)
flashattn_fwd(27-135)
flashattn_bwd_preprocess(144-169)
flashattn_bwd_postprocess(184-201)
flashattn_bwd(209-322)
flash_bwd(232-320)
flashattn_bwd_dsink(326-351)
maybe_contiguous(371-374)
ref_program(399-444)
backward(367-391)
main(447-514)
🪛 Ruff (0.13.3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py
18-18: Avoid specifying long messages outside the exception class
(TRY003)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
18-18: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: build-test-amd
- GitHub Check: format-check
- GitHub Check: format-check
🔇 Additional comments (17)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (12)
27-38: LGTM! Dtype parameter added correctly. The function signature now accepts a `dtype` parameter with a sensible default, enabling runtime dtype selection for the forward kernel.

43-45: LGTM! Configurable sm_scale with sensible default. The `sm_scale` now defaults to `1/sqrt(dim)` when not provided, and the log2(e) conversion is applied consistently. This aligns with standard attention implementations.
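As a rough illustration of the scaling convention described in that comment (a sketch, not the example's exact code): with an exp2-based softmax, the log2(e) factor is folded into the scale so that `2**(scale * s)` equals `exp(sm_scale * s)` for any score `s`.

```python
import math

dim = 128
sm_scale = 1.0 / math.sqrt(dim)      # default when no sm_scale is supplied
scale = sm_scale * 1.4426950408889634  # fold in log2(e) for the exp2-based softmax
s = 3.7                              # arbitrary attention score
assert abs(2 ** (scale * s) - math.exp(sm_scale * s)) < 1e-6
```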
138-144: LGTM! Backward preprocess kernel updated consistently. The JIT decorator and function signature are updated to match the forward kernel's dtype support pattern.

178-184: LGTM! Backward postprocess kernel updated consistently. Consistent with other kernel updates for dtype support and optimization flags.

204-216: LGTM! Main backward kernel updated consistently. The signature now includes both `sm_scale` and `dtype` parameters, matching the forward kernel's interface.

384-385: LGTM! Correct dtype for dk and dv allocations. Unlike `dq`, which requires `float32` for atomic accumulation, `dk` and `dv` can use the input dtype directly since they're written once per location. This change aligns with the dtype-aware design.

403-404: LGTM! Reference implementation now dtype-aware. The `ref_program` signature now accepts a `dtype` parameter, enabling dtype-specific validation.

417-443: LGTM! Reference implementation respects specified dtype. The sinks tensor no longer forces a cast to float (line 417), and the final output is cast to the specified dtype (line 443). This correctly aligns the reference implementation with the dtype-aware kernel behavior.

451-453: LGTM! Main function now accepts dtype configuration. The function signature accepts a dtype string and maps it to the corresponding PyTorch dtype, enabling end-to-end dtype configuration.

464-467: LGTM! Input tensors created with specified dtype. All input tensors (`Q`, `K`, `V`, `sinks`) now use the configurable `torch_dtype`, enabling proper dtype testing.

485-498: LGTM! Dtype-specific tolerances for numerical validation. The assertions now use appropriate tolerance values for each dtype: tighter for float16 (1e-2) and more relaxed for bfloat16 (2e-2), reflecting bfloat16's reduced mantissa precision. All gradient checks are covered.

528-531: LGTM! CLI wired for dtype configuration. The `--dtype` argument is properly added to the CLI and passed through to `main()`, completing the end-to-end dtype configuration path.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (5)
21-46: LGTM! The addition of configurable `sm_scale` and `dtype` parameters is well-implemented. The default `sm_scale = 1.0 / sqrt(dim)` is correct for scaled dot-product attention, and the log2(e) conversion for the scale is appropriate for the exp2-based softmax implementation.

141-172: LGTM! The dtype parameter is correctly propagated through the preprocessing kernel, maintaining consistency with the forward pass.

181-204: LGTM! The postprocessing step correctly converts the accumulated `dQ` gradients from `float32` to the specified dtype, maintaining gradient precision requirements.

404-448: LGTM! The reference implementation correctly handles the dtype parameter, ensuring the output tensor matches the specified precision.

451-538: LGTM! The main function and CLI properly handle dtype selection:
- Correct torch dtype mapping
- Appropriate tolerance thresholds for each precision (float16: 1e-2, bfloat16: 2e-2)
- Clean argument parsing and propagation
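For reference, a hedged sketch of the tolerance scheme these comments describe (the helper name and structure are illustrative, not the example's exact code):

```python
import torch

def assert_matches_reference(out: torch.Tensor, ref: torch.Tensor, dtype: str = "float16") -> None:
    # bfloat16 has fewer mantissa bits than float16, so it gets looser tolerances.
    tol = 2e-2 if dtype == "bfloat16" else 1e-2
    torch.testing.assert_close(out, ref, rtol=tol, atol=tol)
```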
LGTM, we're good to go if we can refactor the codes that I've annotated.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
75-83: Buffer overrun risk: sinks fragment sized by heads but indexed by block_M.

You allocate sinks with length [heads] but write with i in range(block_M). This can OOB or read wrong values. Use a per-row buffer sized [block_M].

```diff
-        sinks = T.alloc_fragment([heads], dtype)
+        sinks = T.alloc_fragment([block_M], dtype)
         T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
         T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
         T.fill(acc_o, 0)
         T.fill(logsum, 0)
         T.fill(scores_max, -T.infinity(accum_dtype))
         for i in T.Parallel(block_M):
             sinks[i] = Sinks[by]
```
♻️ Duplicate comments (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
388-396: Cast dk/dv grads to input dtype before returning.

Gradients for K/V are accumulated in float32 and returned as float32; autograd expects them to match inputs.

```diff
         kernel(q, k, v, do, lse, delta, dq, dk, dv)
         dq = kernel_post(dq)
         kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
         dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
-        return dq, dk, dv, dsinks, None, None
+        # Match gradient dtypes with inputs
+        dk = dk.to(k.dtype)
+        dv = dv.to(v.dtype)
+        return dq, dk, dv, dsinks, None, None
```

examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
389-391: Cast dsinks to the sinks dtype before returning.

Kernel produces float32 accumulation; return should match sinks.dtype.

```diff
-        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
-        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
+        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)
```
🧹 Nitpick comments (8)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (3)
345-346: Use Optional[int] for Python 3.8 compatibility.

Replace PEP 604 unions with typing.Optional to support py38 (consistent with other files).

```diff
-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
```

```diff
-def main(
+def main(
     batch: int = 1,
     heads: int = 32,
     seq_q: int = 256,
     seq_kv: int = 256,
     dim: int = 128,
     groups: int = 8,
-    window_size: int | None = None,
+    window_size: Optional[int] = None,
     dtype: str = "float16",
     tune: bool = False,
 ):
```

[Based on learnings]

Also applies to: 396-398
439-444: Loosen tolerances for bfloat16 to avoid false negatives.

bf16 needs higher tolerances; mirror other examples.

```diff
-    torch.testing.assert_close(
-        kernel(Q, K, V, sinks),
-        ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
-        rtol=1e-2,
-        atol=1e-2)
+    torch.testing.assert_close(
+        kernel(Q, K, V, sinks),
+        ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
+        rtol=2e-2 if dtype == "bfloat16" else 1e-2,
+        atol=2e-2 if dtype == "bfloat16" else 1e-2)
```

```diff
-    if torch.allclose(
-            triton_program(Q, K, V, sinks, window_size),
-            ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
-            rtol=1e-2,
-            atol=1e-2):
+    if torch.allclose(
+            triton_program(Q, K, V, sinks, window_size),
+            ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
+            rtol=2e-2 if dtype == "bfloat16" else 1e-2,
+            atol=2e-2 if dtype == "bfloat16" else 1e-2):
```

Also applies to: 446-454
481-483: Restrict --dtype to valid choices.

Add argparse choices to prevent KeyError on invalid input.

```diff
-parser.add_argument(
-    '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
+parser.add_argument(
+    '--dtype',
+    type=str,
+    choices=["float16", "bfloat16"],
+    default="float16",
+    help="dtype")
```

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)
325-351: Consider accumulating dsinks in float32 but return in input dtype.

Currently the dsinks kernel writes dtype, but intermediate accumulators are float; casting on return is safer/consistent with other paths.

If you keep dsinks as accum_dtype inside the kernel, a return-time cast keeps the API consistent:

```diff
-        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)
```
456-458: Use Optional[int] for Python 3.8 compatibility.

Replace PEP 604 unions.

```diff
-def main(BATCH: int = 1,
+def main(BATCH: int = 1,
          H: int = 8,
          N_CTX: int = 512,
          D_HEAD: int = 64,
          groups: int = 2,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16"):
```
536-537: Restrict --dtype to valid choices.

Add argparse choices to guard against invalid inputs.

```diff
-parser.add_argument(
-    '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
+parser.add_argument(
+    '--dtype',
+    type=str,
+    choices=["float16", "bfloat16"],
+    default="float16",
+    help="dtype")
```

examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)
451-452: Use Optional[int] for Python 3.8 compatibility.

Replace PEP 604 union.

```diff
-def main(BATCH: int = 1,
+def main(BATCH: int = 1,
          H: int = 1,
          N_CTX: int = 512,
          D_HEAD: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16"):
```
529-530: Restrict --dtype to valid choices.

Add argparse choices.

```diff
-parser.add_argument(
-    '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
+parser.add_argument(
+    '--dtype',
+    type=str,
+    choices=["float16", "bfloat16"],
+    default="float16",
+    help="dtype")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- examples/amd/example_amd_flash_attn_bwd.py (1 hunks)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (16 hunks)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (10 hunks)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (16 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (8 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (9 hunks)
- examples/flash_attention/example_gqa_bwd.py (2 hunks)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2 hunks)
- examples/flash_attention/example_mha_bwd_bhsd.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py
🧰 Additional context used
🧬 Code graph analysis (6)
examples/flash_attention/example_gqa_bwd.py (1)
tilelang/language/atomic.py (1)
atomic_add(116-228)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1)
tilelang/language/atomic.py (1)
atomic_add(116-228)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
gen_inputs(239-250)
flashattn(26-187)
ref_program(191-236)
main(128-185)
main(253-312)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (5)
gen_inputs(365-376)
flashattn(30-199)
ref_program(204-249)
main(134-197)
main(379-446)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (6)
flashattn(34-206)
gen_inputs(374-386)
ref_program(211-255)
triton_program(345-371)
main(141-204)
main(389-465)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (9)
get_bwd_configs(11-19)
flashattn_fwd(28-139)
flashattn_bwd_preprocess(148-173)
flashattn_bwd_postprocess(188-205)
flashattn_bwd(213-322)
flashattn_bwd_dsink(326-351)
main(451-520)
ref_program(404-448)
backward(375-396)
tilelang/jit/__init__.py (1)
jit(240-313)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (11)
get_bwd_configs(11-19)
flashattn_fwd(28-136)
flashattn_bwd_preprocess(145-170)
flashattn_bwd_postprocess(185-202)
flashattn_bwd(210-322)
flash_bwd(233-320)
flashattn_bwd_dsink(326-351)
maybe_contiguous(371-374)
ref_program(399-444)
backward(367-391)
main(447-514)
🪛 Ruff (0.13.3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py
19-19: Avoid specifying long messages outside the exception class
(TRY003)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py
19-19: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-metal
```python
    head_kv = heads // groups
    q_shape = [batch, heads, seq_len, dim]
    kv_shape = [batch, head_kv, seq_len, dim]
    dtype = "float16"
    accum_dtype = "float"
```
🛠️ Refactor suggestion | 🟠 Major
Guard against invalid groups configuration.
Add a divisibility check to avoid silent floor when mapping heads to key/value groups.
```diff
     head_kv = heads // groups
+    assert heads % groups == 0, "heads must be divisible by groups"
     q_shape = [batch, heads, seq_len, dim]
     kv_shape = [batch, head_kv, seq_len, dim]
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 head_kv = heads // groups
+assert heads % groups == 0, "heads must be divisible by groups"
 q_shape = [batch, heads, seq_len, dim]
 kv_shape = [batch, head_kv, seq_len, dim]
-dtype = "float16"
 accum_dtype = "float"
```
🤖 Prompt for AI Agents
In examples/attention_sink/example_gqa_sink_bwd_bhsd.py around lines 49 to 53,
the code computes head_kv = heads // groups without checking divisibility which
can silently floor the result; add a guard that validates groups > 0 and heads %
groups == 0 and raise a clear exception (e.g., ValueError) with a message like
"heads must be divisible by groups" if the check fails so the caller is informed
instead of getting an incorrect head_kv; keep the subsequent q_shape/kv_shape
computation unchanged.
```python
    head_kv = heads // groups
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, head_kv, seq_kv, dim]
    dtype = "float16"
    accum_dtype = "float"
```
Guard against invalid groups configuration.
Add an explicit check to prevent silent floor division when heads is not divisible by groups.
```diff
     head_kv = heads // groups
+    assert heads % groups == 0, "heads must be divisible by groups"
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, head_kv, seq_kv, dim]
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 head_kv = heads // groups
+assert heads % groups == 0, "heads must be divisible by groups"
 q_shape = [batch, heads, seq_q, dim]
 kv_shape = [batch, head_kv, seq_kv, dim]
-dtype = "float16"
 accum_dtype = "float"
```
🤖 Prompt for AI Agents
In examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py around
lines 57 to 61, add a guard to validate the groups value before computing
head_kv: ensure groups is a positive non-zero integer and that heads is
divisible by groups (heads % groups == 0); if the check fails, raise a clear
ValueError (or assert) explaining that heads must be divisible by groups to
avoid silent floor-division. This prevents incorrect head_kv computation and
surfaces configuration errors early.
```diff
                 for i, j in T.Parallel(block_N, dim_qk):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
```
Reinstate seq_len guard before atomic_add into dQ (atomic kernel).
With the guard removed, the final tile on sequences where seq_len % block_N ≠ 0 now updates indices beyond dQ’s extent (e.g., indices ≥ seq_len). Please add back the predicate (if k * block_N + i < seq_len) or otherwise clamp the loop so the atomic write never targets rows outside the buffer.
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines 247
to 248, the atomic_add into dQ lacks a seq_len guard and can write past dQ when
seq_len % block_N != 0; restore a predicate so the atomic write only happens for
valid rows (e.g., wrap the T.atomic_add call in a conditional checking if k *
block_N + i < seq_len, or clamp the loop bounds accordingly) to ensure no atomic
writes target indices >= seq_len.
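To make the tail-tile arithmetic concrete, here is a small host-side illustration (plain Python, not kernel code) using the numbers cited in a later comment (seq_len=130, block_N=64):

```python
import math

seq_len, block_N = 130, 64
num_tiles = math.ceil(seq_len / block_N)              # 3 tiles: rows 0-63, 64-127, 128-191
k = num_tiles - 1                                     # last (partial) tile
touched = range(k * block_N, (k + 1) * block_N)       # rows 128..191 without a guard
valid = [i for i in touched if i < seq_len]           # only rows 128 and 129 exist
print(f"{len(valid)} of {block_N} rows in the last tile are in bounds")  # 2 of 64
```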
```diff
                 for i, j in T.Parallel(block_N, dim_qk):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
```
Do not drop the tail-bound check in the split kernel either.
The split-path loop still iterates over the full block_N, so without the previous conditional it will issue atomic_add on out-of-range indices for partial tail blocks. Please restore the k * block_N + i < seq_len guard (or equivalent bound enforcement) around this atomic update too.
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines
363-364, the split-path loop issues atomic_add for full block_N even on tail
blocks, causing out-of-range writes; restore a guard checking k * block_N + i <
seq_len (or equivalent bounds enforcement) around the T.atomic_add so the atomic
update only executes for valid sequence indices, ensuring the loop can still
iterate block_N elements but skips/guards updates for tail elements.
```diff
                 for i, j in T.Parallel(block_N, dim_qk):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
```
Reintroduce the seq_len guard before the dQ atomic_add.
Removing the bounds check means tail tiles now write past the end of dQ whenever seq_len is not an exact multiple of block_N. For example, with seq_len=130 and block_N=64, the last iteration issues atomic_add on indices [128..191] even though only [128,129] exist, which is undefined behaviour on GPU. Please restore the if k * block_N + i < seq_len guard (or clamp the loop bounds) around the atomic update so we never address past the allocated extent.
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd.py around lines 237-238, the
atomic_add to dQ can write past the allocated sequence length for tail tiles;
restore a guard so we never index beyond seq_len. Wrap the T.atomic_add(dQ[...],
...) call with a conditional check if k * block_N + i < seq_len (or adjust the
loop bounds to clamp i so indices < seq_len), ensuring the atomic update only
executes for valid positions.
```diff
                 for i, j in T.Parallel(block_N, dim):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
+                    T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
```
Prevent tail-tile writes past seq_len in dQ.
This kernel also iterates a full block_N, so without the original if k * block_N + i < seq_len guard the atomic add writes beyond the valid sequence rows for trailing tiles (e.g., seq_len=130, block_N=64). Please reinstate the boundary predicate (or adjust loop bounds) before performing atomic_add.
🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_bwd_bhsd.py around lines 231-232, the
atomic_add into dQ iterates the full block_N and can write past seq_len for
trailing tiles; reinstate the original boundary predicate (e.g., check if k *
block_N + i < seq_len) or change the loop bounds so i only ranges over valid
rows before calling T.atomic_add, ensuring the guard is evaluated per iteration
and prevents writes beyond the sequence length.
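Across all four call sites, the requested fix is the same guarded pattern that was removed; shown here for the BHSD layout above (the other kernels differ only in the dQ index order):

```python
# Guarded form: the predicate keeps the partial tail tile from issuing
# atomic_add on rows at or beyond seq_len.
for i, j in T.Parallel(block_N, dim):
    if k * block_N + i < seq_len:
        T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
```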
This pull request introduces support for configurable data types (float16 and bfloat16) for attention kernels in the example scripts, improving flexibility and compatibility for different hardware and precision requirements. It also refactors kernel functions and input generation to accept a `dtype` parameter, ensures reference implementations use the correct precision, and adds command-line options for specifying the data type. Additionally, minor improvements to kernel configuration and memory allocation are included.

Data type configurability and propagation:
- Kernel builders and helpers (`flashattn_fwd`, `flashattn_bwd`, preprocess/postprocess functions, and input generators) now accept a `dtype` argument, defaulting to `"float16"`, allowing easy switching between float16 and bfloat16. [1] [2] [3]
- The `main()` entry points and CLI accept a `dtype` parameter, including correct mapping to PyTorch types. [1] [2] [3]

Reference and test improvements:
- Reference implementations (`ref_program`) and test assertions now use the specified `dtype` for outputs and comparisons, ensuring consistency between kernel and reference results. [1] [2] [3]

Kernel and memory allocation fixes:
- Buffers are now allocated with the configurable `dtype`, matching the input precision. [1] [2] [3]
- In `example_mha_sink_bwd_bhsd.py`, sinks are allocated with `accum_dtype` instead of `dtype` for improved numerical stability.

Kernel configuration and compilation:
- Kernel builds include the bfloat16 compile flag (`-DENABLE_BF16`) and fast math optimizations, ensuring kernels are built with appropriate precision support. [1] [2]

Miscellaneous improvements:
- Changes in `example_mha_sink_bwd_bhsd.py` for reproducibility. [1] [2]
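As a hedged sketch of the CLI plumbing described above (the `--dtype` and `--sm_scale` flag names come from this PR; the parser structure and variable names here are illustrative, not the examples' exact code):

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", type=str, choices=["float16", "bfloat16"], default="float16")
parser.add_argument("--sm_scale", type=float, default=None)  # None -> 1/sqrt(dim) inside the kernel builder
args = parser.parse_args()

# Map the CLI string to the corresponding PyTorch dtype for input generation
# and reference checks.
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[args.dtype]
```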
Summary by CodeRabbit

- New Features
- Refactor
- Bug Fixes
- Documentation