
Conversation


@Rachmanino Rachmanino commented Oct 1, 2025

This pull request introduces support for configurable data types (float16 and bfloat16) for the attention kernels in the example scripts, improving flexibility and compatibility across different hardware and precision requirements. It also refactors the kernel functions and input generation to accept a dtype parameter, ensures the reference implementations use the correct precision, and adds a command-line option for specifying the data type. Additionally, minor improvements to kernel configuration and memory allocation are included.

Data type configurability and propagation:

  • All major attention kernel functions (flashattn_fwd, flashattn_bwd, preprocess/postprocess functions, and input generators) now accept a dtype argument, defaulting to "float16", allowing easy switching between float16 and bfloat16. [1] [2] [3]
  • The main entry points and CLI interfaces for each example script are updated to accept and propagate the dtype parameter, including correct mapping to PyTorch types; a minimal sketch of this wiring is shown below. [1] [2] [3]
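
A minimal sketch of that wiring, assuming PyTorch-style input generation (function and variable names here are illustrative, not the exact code in the examples):

```python
import argparse
import torch

# Map the --dtype string onto the corresponding PyTorch dtype.
TORCH_DTYPES = {"float16": torch.float16, "bfloat16": torch.bfloat16}


def gen_inputs(batch, heads, seq_len, dim, torch_dtype=torch.float16, device="cuda"):
    # Inputs are created directly in the requested precision.
    q = torch.randn(batch, heads, seq_len, dim, dtype=torch_dtype, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sinks = torch.randn(heads, dtype=torch_dtype, device=device)
    return q, k, v, sinks


def main(dtype: str = "float16"):
    torch_dtype = TORCH_DTYPES[dtype]
    q, k, v, sinks = gen_inputs(1, 32, 256, 128, torch_dtype)
    # The kernel builder would receive the same dtype string,
    # e.g. kernel = flashattn_fwd(..., dtype=dtype).
    ...


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", type=str, choices=["float16", "bfloat16"], default="float16")
    main(parser.parse_args().dtype)
```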

Reference and test improvements:

  • Reference implementations (ref_program) and test assertions now use the specified dtype for outputs and comparisons, ensuring consistency between kernel and reference results; see the comparison sketch below. [1] [2] [3]
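
The comparison side of this, sketched with placeholder tensor names and the tolerance values discussed later in the review:

```python
import torch


def check_against_reference(out: torch.Tensor, out_ref: torch.Tensor, dtype: str) -> None:
    # bfloat16 carries fewer mantissa bits than float16, so it gets a looser tolerance.
    tol = 2e-2 if dtype == "bfloat16" else 1e-2
    torch.testing.assert_close(out, out_ref, rtol=tol, atol=tol)
```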

Kernel and memory allocation fixes:

  • Memory allocations for intermediate tensors (such as gradients and sinks) now use the correct dtype, matching the input precision. [1] [2] [3]
  • In example_mha_sink_bwd_bhsd.py, sinks are allocated with accum_dtype instead of dtype for improved numerical stability.

Kernel configuration and compilation:

  • Kernel compilation flags now include support for bfloat16 (-DENABLE_BF16) and fast math optimizations, ensuring kernels are built with appropriate precision support. [1] [2]

Miscellaneous improvements:

  • Minor code cleanups, such as removing commented code and disabling cache in example_mha_sink_bwd_bhsd.py for reproducibility. [1] [2]

Summary by CodeRabbit

  • New Features

    • Multi-dtype support (float16 and bfloat16) across attention examples; runtime/tests honor chosen dtype and new --dtype CLI option; optional sm_scale for dynamic scaling.
  • Refactor

    • Propagated dtype through inputs, kernels, benchmarks, and validations; added BF16 compile flags and adjusted defaults/configs per SM/version.
  • Bug Fixes

    • Removed backward-path bounds guards around atomic updates in multiple examples (changes backward kernel control flow).
  • Documentation

    • Updated links, switched benchmark dtype to bf16, refreshed performance tables, and clarified backward-optimization notes.


coderabbitai bot commented Oct 1, 2025

Walkthrough

Adds runtime-selectable dtype (float16|bfloat16) and optional sm_scale to multiple attention-sink and MHA examples, propagates dtype through kernel builds, input generation, ref programs, tests and CLI (--dtype); and removes several backward-path seq_len bounds checks around atomic_add updates in multiple backward examples.

Changes

  • Documentation
    File(s): examples/attention_sink/README.md
    Summary: Updated Triton link formatting; switched benchmark dtype to bf16; refreshed performance table values; softened backward-optimization note.
  • Attention-sink — GQA backward (dtype propagation)
    File(s): examples/attention_sink/example_gqa_sink_bwd_bhsd.py
    Summary: Added dtype param to forward/backward/pre/postprocess and kernel signatures; propagate dtype to ref_program, tests and CLI (--dtype); added BF16 compile flags and dtype-aware accumulator/layout handling.
  • Attention-sink — GQA forward (wgmma pipelined, dtype + sm_scale)
    File(s): examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
    Summary: Added dtype and optional sm_scale to flashattn; propagated dtype to gen_inputs, ref_program, main, kernel builds and comparisons; CLI --dtype added.
  • Attention-sink — MHA backward (configs + dtype)
    File(s): examples/attention_sink/example_mha_sink_bwd_bhsd.py
    Summary: Adjusted SM-specific config tuples; added dtype and optional sm_scale across fwd/bwd/pre/postprocess/dsink; enabled BF16 compile flags; ref_program/main accept dtype; CLI --dtype added.
  • Attention-sink — MHA forward (dtype + sm_scale)
    File(s): examples/attention_sink/example_mha_sink_fwd_bhsd.py, examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
    Summary: Added dtype and optional sm_scale to flashattn; updated gen_inputs, ref_program, main to accept/propagate dtype; benchmarking/validation use selected dtype; CLI --dtype added.
  • Flash-attention / AMD examples — backward atomic_add bounds removal
    File(s): examples/amd/example_amd_flash_attn_bwd.py, examples/flash_attention/example_gqa_bwd.py, examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py, examples/flash_attention/example_mha_bwd_bhsd.py
    Summary: Removed conditional seq_len guards around T.atomic_add updates in multiple backward kernels so atomic_add executes unconditionally, changing control flow and potentially enabling out-of-bounds updates for positions beyond seq_len.
  • Misc — typing & CLI wiring
    File(s): examples/... (multiple examples)
    Summary: Added Optional typing imports, extended function signatures to include dtype parameters, and added --dtype CLI arguments across examples.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as CLI (--dtype/--sm_scale)
  participant Main as main()
  participant Gen as gen_inputs(torch_dtype)
  participant Builder as kernel builder (flashattn(dtype,sm_scale))
  participant Kern as TileLang Kernel
  participant Ref as ref_program(torch_dtype)

  CLI->>Main: parse args (--dtype, --sm_scale)
  Main->>Gen: create Q,K,V,sinks (torch_dtype)
  Main->>Builder: build kernel (dtype, sm_scale)
  Main->>Kern: invoke kernel(Q,K,V,sinks)
  Kern-->>Main: O (kernel dtype)
  Main->>Ref: run ref_program(Q,K,V,sinks, dtype=torch_dtype)
  Ref-->>Main: O_ref (cast to torch_dtype)
  Main-->>CLI: report metrics and checks
sequenceDiagram
  autonumber
  participant Main as main(dtype)
  participant Prep as flashattn_bwd_preprocess(dtype)
  participant Bwd as flashattn_bwd(dtype, sm_scale)
  participant Atomic as atomic_add loop
  participant Post as flashattn_bwd_postprocess(dtype)

  Main->>Prep: prepare buffers/state (dtype-aware)
  Main->>Bwd: run backward (accum_dtype chosen by dtype)
  Bwd->>Atomic: loop over k,i,j — atomic_add(dQ...)
  Note right of Atomic #ffefc6: guard removed — atomic_add now unconditional
  Atomic-->>Post: partial grads
  Post-->>Main: finalized dQ,dK,dV (cast to dtype)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • tzj-fxz
  • LeiWang1999

Poem

I nibble through bits with a twitch and a hop,
float or bfloat, I pick which to stop.
Kernels shaped like carrots in my pack,
forward, backward — I stitch the track.
Hooray for dtypes! I bound and then hop back 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Resolution: You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check (✅ Passed): The title succinctly captures the primary modifications introduced by the pull request—adding bfloat16 support and a user-defined sm_scale parameter to the attention sink examples—and accurately reflects the main change without extraneous detail.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c8a376c and 422b301.

📒 Files selected for processing (2)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (16 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (16 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (10)
  • get_bwd_configs (11-19)
  • flashattn_fwd (28-136)
  • flashattn_bwd_preprocess (145-170)
  • flashattn_bwd_postprocess (185-202)
  • flashattn_bwd (210-322)
  • flash_bwd (233-320)
  • maybe_contiguous (371-374)
  • ref_program (399-444)
  • backward (367-391)
  • main (447-514)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (9)
  • get_bwd_configs (11-19)
  • flashattn_fwd (28-139)
  • flashattn_bwd_preprocess (148-173)
  • flashattn_bwd_postprocess (188-205)
  • flashattn_bwd (213-321)
  • flashattn_bwd_dsink (325-350)
  • main (450-519)
  • ref_program (403-447)
  • backward (374-395)
tilelang/jit/__init__.py (1)
  • jit (240-313)
🪛 Ruff (0.13.3)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py

19-19: Avoid specifying long messages outside the exception class

(TRY003)

examples/attention_sink/example_mha_sink_bwd_bhsd.py

19-19: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check
🔇 Additional comments (11)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (6)

8-8: LGTM!

The Optional import is correctly used for type hints in the ref_program signature.


34-46: LGTM!

The sm_scale parameter addition and default computation are correctly implemented. The default scaling factor (1.0 / dim)**0.5 follows standard attention mechanics, and the scale is consistently applied in both forward and backward passes.

Also applies to: 216-224


23-27: LGTM!

The JIT decorator updates correctly enable bfloat16 support via the -DENABLE_BF16 compile flag, aligning with the PR objectives. The TL_ENABLE_FAST_MATH configuration is appropriately set in pass configs.

Also applies to: 140-144, 180-184, 205-209


384-385: LGTM! Valid optimization.

Changing dk and dv from torch.zeros to torch.empty and using q.dtype instead of float32 is correct. Unlike dq (which uses atomic_add and requires zero-initialization with float32 accumulation), dk and dv are written directly via T.copy operations (lines 319-320), making zero-initialization unnecessary and allowing the use of input dtype.
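
A condensed sketch of the allocation pattern being approved here (buffer names follow the comment above; shapes and devices are placeholders):

```python
import torch


def alloc_grad_buffers(q: torch.Tensor, kv_shape) -> tuple:
    # dq is accumulated with atomic_add, so it must start zeroed and stay in float32.
    dq = torch.zeros(q.shape, dtype=torch.float32, device=q.device)
    # dk/dv are written once via T.copy, so uninitialized buffers in the
    # input dtype are sufficient; no zero-fill or float32 staging is needed.
    dk = torch.empty(kv_shape, dtype=q.dtype, device=q.device)
    dv = torch.empty(kv_shape, dtype=q.dtype, device=q.device)
    return dq, dk, dv
```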


403-404: LGTM!

The dtype parameter addition to ref_program and the corresponding output casting (line 443) correctly implement dtype-aware behavior. Removing the .float() call on sinks (line 417) is appropriate since the computation handles the dtype naturally.

Also applies to: 417-417, 443-443


447-498: LGTM!

The dtype parameter integration in main is complete and correct:

  • Proper torch dtype mapping (line 453)
  • All tensors created with the specified dtype (lines 464-467, 477)
  • Test tolerances appropriately relaxed for bfloat16 due to its reduced precision (lines 485-488)
  • CLI argument added for user control (lines 528-529)

Also applies to: 517-531

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (5)

8-8: LGTM!

Adding Optional import supports the type hints used later in the file.


244-245: Verify accum_dtype usage for dK/dV tensors.

The kernel now uses accum_dtype for dK/dV tensor signatures and their shared buffers, which differs from the MHA example (example_mha_sink_bwd_bhsd.py) that uses dtype for these. While using accum_dtype provides better numerical precision during gradient accumulation, ensure this is intentional and consistent with the GQA requirements.

Additionally, verify that the implicit cast from accum_dtype shared buffers to output tensors during atomic_add (lines 317, 319) is handled correctly by the compiler.

Also applies to: 262-263


457-457: LGTM!

The dtype mapping and tolerance adjustments are well-implemented:

  • Clean string-to-torch.dtype mapping (line 457)
  • Appropriate tolerance relaxation for bfloat16 (2e-2) vs float16 (1e-2) reflecting the reduced precision of these formats

Also applies to: 490-493


365-366: LGTM!

Dtype determination and propagation through the forward and backward kernels is correctly implemented, ensuring consistent precision throughout the attention computation pipeline.

Also applies to: 378-383


314-319: Validate slice bounds for all atomic_add calls
No explicit bounds checks remain around the atomic_add invocations on dQ, dV, and dK. Confirm that the computed loop_ed and the built-in slice semantics of T.copy/T.atomic_add safely handle partial blocks when seq_len % block_N or seq_len % block_M ≠ 0; reintroduce boundary guards if they do not.
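
A plain-Python analogue of the guard in question (illustrative only; the real kernels use T.Parallel and T.atomic_add, as shown in the diffs further down this page):

```python
import torch


def accumulate_tail_safe(dQ: torch.Tensor, dq_tile: torch.Tensor,
                         k: int, block_N: int, seq_len: int) -> torch.Tensor:
    # Only rows that fall inside the valid sequence range are accumulated;
    # the trailing part of a partial block is skipped.
    for i in range(block_N):
        row = k * block_N + i
        if row < seq_len:  # the bound the review asks to verify or reinstate
            dQ[row] += dq_tile[i]
    return dQ
```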




github-actions bot commented Oct 1, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@Rachmanino Rachmanino marked this pull request as ready for review October 9, 2025 03:26

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
examples/attention_sink/README.md (1)

3-3: Use descriptive link text.

Markdownlint (MD059) warns about generic “here”; please name the target (e.g., “optimized Triton implementation”).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 481cae4 and e8fa8bb.

📒 Files selected for processing (6)
  • examples/attention_sink/README.md (2 hunks)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (16 hunks)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (9 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (15 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py (7 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (3)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (6)
  • flashattn (33-205)
  • gen_inputs (373-385)
  • ref_program (210-254)
  • triton_program (344-370)
  • main (140-203)
  • main (388-464)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
  • flashattn (25-186)
  • gen_inputs (238-249)
  • ref_program (190-235)
  • main (127-184)
  • main (252-311)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (3)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
  • gen_inputs (238-249)
  • flashattn (25-186)
  • ref_program (190-235)
  • main (127-184)
  • main (252-311)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (5)
  • gen_inputs (364-375)
  • flashattn (29-198)
  • ref_program (203-248)
  • main (133-196)
  • main (378-445)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (3)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (5)
  • flashattn (33-205)
  • gen_inputs (373-385)
  • ref_program (210-254)
  • main (140-203)
  • main (388-464)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (5)
  • flashattn (29-198)
  • gen_inputs (364-375)
  • ref_program (203-248)
  • main (133-196)
  • main (378-445)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (3)
tilelang/jit/__init__.py (1)
  • jit (240-313)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (9)
  • flashattn_fwd (27-138)
  • flashattn_bwd_preprocess (147-172)
  • flashattn_bwd_postprocess (187-204)
  • flashattn_bwd (212-322)
  • get_bwd_configs (10-18)
  • flashattn_bwd_dsink (326-351)
  • main (451-520)
  • ref_program (404-448)
  • backward (375-396)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
tilelang/jit/__init__.py (1)
  • jit (240-313)
🪛 markdownlint-cli2 (0.18.1)
examples/attention_sink/README.md

3-3: Link text should be descriptive

(MD059, descriptive-link-text)

🪛 Ruff (0.13.3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py

18-18: Avoid specifying long messages outside the exception class

(TRY003)

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

18-18: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-amd
  • GitHub Check: bot-task

Comment on lines +389 to 396
 dk = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
 dv = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
 kernel(q, k, v, do, lse, delta, dq, dk, dv)
 dq = kernel_post(dq)

-kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX)
+kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
 dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
 return dq, dk, dv, dsinks, None, None


⚠️ Potential issue | 🟠 Major

Ensure gradients match the input dtype before returning.

dk and dv remain torch.float32 accumulators when returned from backward, but the corresponding inputs (K, V) are created in float16/bfloat16. PyTorch expects gradients in the same dtype; returning float32 here can break gradient checks or trigger type errors. Convert dk/dv back to the original dtype (mirroring the dQ postprocess) before returning from _attention.backward.

🤖 Prompt for AI Agents
In examples/attention_sink/example_gqa_sink_bwd_bhsd.py around lines 389 to 396,
dk and dv are created as float32 but returned to match K/V which may be
float16/bfloat16; convert dk and dv back to the attention input dtype before
returning (mirror dq postprocessing) by casting them to the same dtype variable
used for inputs (e.g., dk = dk.to(dtype) and dv = dv.to(dtype)) so gradients
have the correct dtype.

Comment on lines +392 to 394
kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
return dq, dk, dv, dsinks, None


⚠️ Potential issue | 🟠 Major

Cast dsinks back to the input dtype before returning.

flashattn_bwd_dsink now produces float32 (accumulation dtype), but autograd expects the gradient for sinks to match sinks.dtype (float16/bfloat16). Return path currently hands PyTorch a float32 tensor, causing silent dtype promotion and breaking mixed-precision assumptions. Convert before returning:

-        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
-        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
+        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
-dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
-return dq, dk, dv, dsinks, None
+kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
+dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)
+return dq, dk, dv, dsinks, None
🤖 Prompt for AI Agents
In examples/attention_sink/example_mha_sink_bwd_bhsd.py around lines 392 to 394,
dsinks is produced in float32 (accumulation dtype) but returned directly,
causing PyTorch to receive a float32 gradient while sinks has float16/bfloat16;
convert dsinks back to the input gradient dtype before returning (e.g., cast to
sinks.dtype or use .type_as(sinks)/.to(sinks.dtype)) so the returned gradient
matches sinks' original dtype and preserves mixed-precision behavior.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

389-391: Cast dsinks back to the input dtype before returning.

This issue was previously flagged and remains unresolved. The flashattn_bwd_dsink kernel now produces float32 (accumulation dtype) at line 335, but autograd expects the gradient for sinks to match sinks.dtype (float16/bfloat16). The current implementation returns a float32 tensor directly, causing silent dtype promotion and breaking mixed-precision assumptions.

Apply this diff:

 kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
-dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(q.dtype)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)

207-322: Critical: Add postprocessing to convert dK and dV from accum_dtype to dtype.

The kernel signature declares dK and dV as accum_dtype (float32) for better numerical stability during atomic accumulation (lines 243-244). However, the backward method (line 396) returns these float32 gradients directly without converting them to match the input dtype of K and V (which are float16/bfloat16). PyTorch's autograd expects gradients to have the same dtype as their corresponding inputs, so this mismatch will cause type errors or break gradient checks.

Similar to how dQ is postprocessed (line 392: dq = kernel_post(dq)), you need to convert dk and dv back to the original dtype before returning.

Apply this diff to add the dtype conversion before returning:

         kernel(q, k, v, do, lse, delta, dq, dk, dv)
         dq = kernel_post(dq)
 
+        # Convert dk and dv from float32 accumulation back to input dtype
+        torch_dtype = q.dtype
+        dk = dk.to(torch_dtype)
+        dv = dv.to(torch_dtype)
+
         kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
         dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
         return dq, dk, dv, dsinks, None, None
🧹 Nitpick comments (1)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)

325-351: Consider using accum_dtype for dsinks to match MHA and improve precision.

The dsinks tensor and dsink_fragment are allocated with dtype (lines 335, 341), but the computation (lines 347-348) involves exp2 and multiplication with accum_dtype values (delta_fragment). The result is then summed in the backward method (line 395: .sum(0).sum(1)), which could benefit from higher precision accumulation.

For consistency with the MHA version (see example_mha_sink_bwd_bhsd.py), consider using accum_dtype for both dsinks and dsink_fragment. This would improve numerical stability, especially when accumulated across batches and sequences.

Apply this diff:

     @T.prim_func
     def flash_bwd_dsink(
             Sinks: T.Tensor([heads], dtype),  # type: ignore
             Delta: T.Tensor(shape, accum_dtype),  # type: ignore
             lse: T.Tensor(shape, accum_dtype),  # type: ignore
-            dsinks: T.Tensor(shape, dtype),  # type: ignore
+            dsinks: T.Tensor(shape, accum_dtype),  # type: ignore
     ):
         with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz):
             sink = T.alloc_local([1], dtype)
             lse_fragment = T.alloc_fragment([block], accum_dtype)
             delta_fragment = T.alloc_fragment([block], accum_dtype)
-            dsink_fragment = T.alloc_fragment([block], dtype)
+            dsink_fragment = T.alloc_fragment([block], accum_dtype)

Then convert the result to dtype after the sum operation in the backward method if needed.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e8fa8bb and 0e236a0.

📒 Files selected for processing (2)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (16 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (15 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (3)
tilelang/jit/__init__.py (1)
  • jit (240-313)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (8)
  • flashattn_fwd (27-138)
  • flashattn_bwd_preprocess (147-172)
  • flashattn_bwd_postprocess (187-204)
  • flashattn_bwd (212-322)
  • get_bwd_configs (10-18)
  • flashattn_bwd_dsink (326-351)
  • main (451-520)
  • ref_program (404-448)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)
tilelang/jit/__init__.py (1)
  • jit (240-313)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (10)
  • flashattn_fwd (27-135)
  • flashattn_bwd_preprocess (144-169)
  • flashattn_bwd_postprocess (184-201)
  • flashattn_bwd (209-322)
  • flash_bwd (232-320)
  • flashattn_bwd_dsink (326-351)
  • maybe_contiguous (371-374)
  • ref_program (399-444)
  • backward (367-391)
  • main (447-514)
🪛 Ruff (0.13.3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py

18-18: Avoid specifying long messages outside the exception class

(TRY003)

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

18-18: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check
  • GitHub Check: format-check
🔇 Additional comments (17)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (12)

27-38: LGTM! Dtype parameter added correctly.

The function signature now accepts a dtype parameter with a sensible default, enabling runtime dtype selection for the forward kernel.


43-45: LGTM! Configurable sm_scale with sensible default.

The sm_scale now defaults to 1/sqrt(dim) when not provided, and the log2(e) conversion is applied consistently. This aligns with standard attention implementations.


138-144: LGTM! Backward preprocess kernel updated consistently.

The JIT decorator and function signature are updated to match the forward kernel's dtype support pattern.


178-184: LGTM! Backward postprocess kernel updated consistently.

Consistent with other kernel updates for dtype support and optimization flags.


204-216: LGTM! Main backward kernel updated consistently.

The signature now includes both sm_scale and dtype parameters, matching the forward kernel's interface.


384-385: LGTM! Correct dtype for dk and dv allocations.

Unlike dq which requires float32 for atomic accumulation, dk and dv can use the input dtype directly since they're written once per location. This change aligns with the dtype-aware design.


403-404: LGTM! Reference implementation now dtype-aware.

The ref_program signature now accepts a dtype parameter, enabling dtype-specific validation.


417-443: LGTM! Reference implementation respects specified dtype.

The sinks tensor no longer forces a cast to float (line 417), and the final output is cast to the specified dtype (line 443). This correctly aligns the reference implementation with the dtype-aware kernel behavior.


451-453: LGTM! Main function now accepts dtype configuration.

The function signature accepts a dtype string and maps it to the corresponding PyTorch dtype, enabling end-to-end dtype configuration.


464-467: LGTM! Input tensors created with specified dtype.

All input tensors (Q, K, V, sinks) now use the configurable torch_dtype, enabling proper dtype testing.


485-498: LGTM! Dtype-specific tolerances for numerical validation.

The assertions now use appropriate tolerance values for each dtype: tighter for float16 (1e-2) and more relaxed for bfloat16 (2e-2), reflecting bfloat16's reduced mantissa precision. All gradient checks are covered.


528-531: LGTM! CLI wired for dtype configuration.

The --dtype argument is properly added to the CLI and passed through to main(), completing the end-to-end dtype configuration path.

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (5)

21-46: LGTM!

The addition of configurable sm_scale and dtype parameters is well-implemented. The default sm_scale = 1.0 / sqrt(dim) is correct for scaled dot-product attention, and the log2(e) conversion for the scale is appropriate for the exp2-based softmax implementation.
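
Editorial aside, not part of the diff: the log2(e) folding works because the softmax exponentials can be evaluated with the hardware exp2 path,

$$\exp\left(s \, q k^\top\right) = 2^{\left(s \log_2 e\right)\, q k^\top},$$

so multiplying sm_scale by log2(e) once up front lets the inner loop call exp2 directly.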


141-172: LGTM!

The dtype parameter is correctly propagated through the preprocessing kernel, maintaining consistency with the forward pass.


181-204: LGTM!

The postprocessing step correctly converts the accumulated dQ gradients from float32 to the specified dtype, maintaining gradient precision requirements.


404-448: LGTM!

The reference implementation correctly handles the dtype parameter, ensuring the output tensor matches the specified precision.


451-538: LGTM!

The main function and CLI properly handle dtype selection:

  • Correct torch dtype mapping
  • Appropriate tolerance thresholds for each precision (float16: 1e-2, bfloat16: 2e-2)
  • Clean argument parsing and propagation

LeiWang1999 previously approved these changes Oct 10, 2025

@LeiWang1999 LeiWang1999 left a comment


LGTM, we're good to go if we can refactor the codes that I've annotated.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)

75-83: Buffer overrun risk: sinks fragment sized by heads but indexed by block_M.

You allocate sinks with length [heads] but write it with i in range(block_M). This can go out of bounds or read wrong values. Use a per-row buffer sized [block_M].

-            sinks = T.alloc_fragment([heads], dtype)
+            sinks = T.alloc_fragment([block_M], dtype)
             T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
             T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
             for i in T.Parallel(block_M):
                 sinks[i] = Sinks[by]
♻️ Duplicate comments (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)

388-396: Cast dk/dv grads to input dtype before returning.

Gradients for K/V are accumulated in float32 and returned as float32; Autograd expects them to match inputs.

         kernel(q, k, v, do, lse, delta, dq, dk, dv)
         dq = kernel_post(dq)
 
         kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
         dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
-        return dq, dk, dv, dsinks, None, None
+        # Match gradient dtypes with inputs
+        dk = dk.to(k.dtype)
+        dv = dv.to(v.dtype)
+        return dq, dk, dv, dsinks, None, None
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

389-391: Cast dsinks to the sinks dtype before returning.

Kernel produces float32 accumulation; return should match sinks.dtype.

-        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
-        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+        kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
+        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)
🧹 Nitpick comments (8)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (3)

345-346: Use Optional[int] for Python 3.8 compatibility.

Replace PEP 604 unions with typing.Optional to support py38 (consistent with other files).

-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
-def main(
+def main(
     batch: int = 1,
     heads: int = 32,
     seq_q: int = 256,
     seq_kv: int = 256,
     dim: int = 128,
     groups: int = 8,
-    window_size: int | None = None,
+    window_size: Optional[int] = None,
     dtype: str = "float16",
     tune: bool = False,
 ):

[Based on learnings]

Also applies to: 396-398


439-444: Loosen tolerances for bfloat16 to avoid false negatives.

bf16 needs higher tolerances; mirror other examples.

-    torch.testing.assert_close(
-        kernel(Q, K, V, sinks),
-        ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
-        rtol=1e-2,
-        atol=1e-2)
+    torch.testing.assert_close(
+        kernel(Q, K, V, sinks),
+        ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
+        rtol=2e-2 if dtype == "bfloat16" else 1e-2,
+        atol=2e-2 if dtype == "bfloat16" else 1e-2)
-    if torch.allclose(
-            triton_program(Q, K, V, sinks, window_size),
-            ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
-            rtol=1e-2,
-            atol=1e-2):
+    if torch.allclose(
+            triton_program(Q, K, V, sinks, window_size),
+            ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
+            rtol=2e-2 if dtype == "bfloat16" else 1e-2,
+            atol=2e-2 if dtype == "bfloat16" else 1e-2):

Also applies to: 446-454


481-483: Restrict --dtype to valid choices.

Add argparse choices to prevent KeyError on invalid input.

-parser.add_argument(
-    '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
+parser.add_argument(
+    '--dtype',
+    type=str,
+    choices=["float16", "bfloat16"],
+    default="float16",
+    help="dtype")
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)

325-351: Consider accumulating dsinks in float32 but return in input dtype.

Currently the dsinks kernel writes dtype, but the intermediate accumulators are float; casting on return is safer and consistent with other paths.

If you keep dsinks as accum_dtype inside the kernel, return-time cast keeps API consistent:

-        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
+        dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1).to(sinks.dtype)

456-458: Use Optional[int] for Python 3.8 compatibility.

Replace PEP 604 unions.

-def main(BATCH: int = 1,
-         H: int = 8,
-         N_CTX: int = 512,
-         D_HEAD: int = 64,
-         groups: int = 2,
-         window_size: int | None = None,
-         dtype: str = "float16"):
+def main(BATCH: int = 1,
+         H: int = 8,
+         N_CTX: int = 512,
+         D_HEAD: int = 64,
+         groups: int = 2,
+         window_size: Optional[int] = None,
+         dtype: str = "float16"):

536-537: Restrict --dtype to valid choices.

Add argparse choices to guard against invalid inputs.

-parser.add_argument(
-        '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
+parser.add_argument(
+        '--dtype',
+        type=str,
+        choices=["float16", "bfloat16"],
+        default="float16",
+        help="dtype")
examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)

451-452: Use Optional[int] for Python 3.8 compatibility.

Replace PEP 604 union.

-def main(BATCH: int = 1,
+def main(BATCH: int = 1,
          H: int = 1,
          N_CTX: int = 512,
          D_HEAD: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16"):

529-530: Restrict --dtype to valid choices.

Add argparse choices.

-parser.add_argument(
-        '--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
+parser.add_argument(
+        '--dtype',
+        type=str,
+        choices=["float16", "bfloat16"],
+        default="float16",
+        help="dtype")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e236a0 and c8a376c.

📒 Files selected for processing (9)
  • examples/amd/example_amd_flash_attn_bwd.py (1 hunks)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (16 hunks)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (10 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (16 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py (8 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (9 hunks)
  • examples/flash_attention/example_gqa_bwd.py (2 hunks)
  • examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2 hunks)
  • examples/flash_attention/example_mha_bwd_bhsd.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py
🧰 Additional context used
🧬 Code graph analysis (6)
examples/flash_attention/example_gqa_bwd.py (1)
tilelang/language/atomic.py (1)
  • atomic_add (116-228)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1)
tilelang/language/atomic.py (1)
  • atomic_add (116-228)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
  • gen_inputs (239-250)
  • flashattn (26-187)
  • ref_program (191-236)
  • main (128-185)
  • main (253-312)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (5)
  • gen_inputs (365-376)
  • flashattn (30-199)
  • ref_program (204-249)
  • main (134-197)
  • main (379-446)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (6)
  • flashattn (34-206)
  • gen_inputs (374-386)
  • ref_program (211-255)
  • triton_program (345-371)
  • main (141-204)
  • main (389-465)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (9)
  • get_bwd_configs (11-19)
  • flashattn_fwd (28-139)
  • flashattn_bwd_preprocess (148-173)
  • flashattn_bwd_postprocess (188-205)
  • flashattn_bwd (213-322)
  • flashattn_bwd_dsink (326-351)
  • main (451-520)
  • ref_program (404-448)
  • backward (375-396)
tilelang/jit/__init__.py (1)
  • jit (240-313)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (11)
  • get_bwd_configs (11-19)
  • flashattn_fwd (28-136)
  • flashattn_bwd_preprocess (145-170)
  • flashattn_bwd_postprocess (185-202)
  • flashattn_bwd (210-322)
  • flash_bwd (233-320)
  • flashattn_bwd_dsink (326-351)
  • maybe_contiguous (371-374)
  • ref_program (399-444)
  • backward (367-391)
  • main (447-514)
🪛 Ruff (0.13.3)
examples/attention_sink/example_mha_sink_bwd_bhsd.py

19-19: Avoid specifying long messages outside the exception class

(TRY003)

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

19-19: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-metal

Comment on lines 49 to 53
head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim]
dtype = "float16"
accum_dtype = "float"



🛠️ Refactor suggestion | 🟠 Major

Guard against invalid groups configuration.

Add a divisibility check to avoid silent floor division when mapping heads to key/value groups.

     head_kv = heads // groups
+    assert heads % groups == 0, "heads must be divisible by groups"
     q_shape = [batch, heads, seq_len, dim]
     kv_shape = [batch, head_kv, seq_len, dim]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-head_kv = heads // groups
-q_shape = [batch, heads, seq_len, dim]
-kv_shape = [batch, head_kv, seq_len, dim]
-dtype = "float16"
-accum_dtype = "float"
+head_kv = heads // groups
+assert heads % groups == 0, "heads must be divisible by groups"
+q_shape = [batch, heads, seq_len, dim]
+kv_shape = [batch, head_kv, seq_len, dim]
+accum_dtype = "float"
🤖 Prompt for AI Agents
In examples/attention_sink/example_gqa_sink_bwd_bhsd.py around lines 49 to 53,
the code computes head_kv = heads // groups without checking divisibility which
can silently floor the result; add a guard that validates groups > 0 and heads %
groups == 0 and raise a clear exception (e.g., ValueError) with a message like
"heads must be divisible by groups" if the check fails so the caller is informed
instead of getting an incorrect head_kv; keep the subsequent q_shape/kv_shape
computation unchanged.

Comment on lines 57 to 61
head_kv = heads // groups
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, head_kv, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"



⚠️ Potential issue | 🟠 Major

Guard against invalid groups configuration.

Add an explicit check to prevent silent floor division when heads is not divisible by groups.

     head_kv = heads // groups
+    assert heads % groups == 0, "heads must be divisible by groups"
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, head_kv, seq_kv, dim]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-head_kv = heads // groups
-q_shape = [batch, heads, seq_q, dim]
-kv_shape = [batch, head_kv, seq_kv, dim]
-dtype = "float16"
-accum_dtype = "float"
+head_kv = heads // groups
+assert heads % groups == 0, "heads must be divisible by groups"
+q_shape = [batch, heads, seq_q, dim]
+kv_shape = [batch, head_kv, seq_kv, dim]
+accum_dtype = "float"
🤖 Prompt for AI Agents
In examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py around
lines 57 to 61, add a guard to validate the groups value before computing
head_kv: ensure groups is a positive non-zero integer and that heads is
divisible by groups (heads % groups == 0); if the check fails, raise a clear
ValueError (or assert) explaining that heads must be divisible by groups to
avoid silent floor-division. This prevents incorrect head_kv computation and
surfaces configuration errors early.

Comment on lines 247 to +248
 for i, j in T.Parallel(block_N, dim_qk):
-    if k * block_N + i < seq_len:
-        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])


⚠️ Potential issue | 🔴 Critical

Reinstate seq_len guard before atomic_add into dQ (atomic kernel).

With the guard removed, the final tile on sequences where seq_len % block_N ≠ 0 now updates indices beyond dQ’s extent (e.g., indices ≥ seq_len). Please add back the predicate (if k * block_N + i < seq_len) or otherwise clamp the loop so the atomic write never targets rows outside the buffer.

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines 247
to 248, the atomic_add into dQ lacks a seq_len guard and can write past dQ when
seq_len % block_N != 0; restore a predicate so the atomic write only happens for
valid rows (e.g., wrap the T.atomic_add call in a conditional checking if k *
block_N + i < seq_len, or clamp the loop bounds accordingly) to ensure no atomic
writes target indices >= seq_len.

 for i, j in T.Parallel(block_N, dim_qk):
-    if k * block_N + i < seq_len:
-        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])


⚠️ Potential issue | 🔴 Critical

Do not drop the tail-bound check in the split kernel either.

The split-path loop still iterates over the full block_N, so without the previous conditional it will issue atomic_add on out-of-range indices for partial tail blocks. Please restore the k * block_N + i < seq_len guard (or equivalent bound enforcement) around this atomic update too.

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines
363-364, the split-path loop issues atomic_add for full block_N even on tail
blocks, causing out-of-range writes; restore a guard checking k * block_N + i <
seq_len (or equivalent bounds enforcement) around the T.atomic_add so the atomic
update only executes for valid sequence indices, ensuring the loop can still
iterate block_N elements but skips/guards updates for tail elements.

Comment on lines 237 to +238
 for i, j in T.Parallel(block_N, dim_qk):
-    if k * block_N + i < seq_len:
-        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])


⚠️ Potential issue | 🔴 Critical

Reintroduce the seq_len guard before the dQ atomic_add.

Removing the bounds check means tail tiles now write past the end of dQ whenever seq_len is not an exact multiple of block_N. For example, with seq_len=130 and block_N=64, the last iteration issues atomic_add on indices [128..191] even though only [128,129] exist, which is undefined behaviour on GPU. Please restore the if k * block_N + i < seq_len guard (or clamp the loop bounds) around the atomic update so we never address past the allocated extent.

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd.py around lines 237-238, the
atomic_add to dQ can write past the allocated sequence length for tail tiles;
restore a guard so we never index beyond seq_len. Wrap the T.atomic_add(dQ[...],
...) call with a conditional check if k * block_N + i < seq_len (or adjust the
loop bounds to clamp i so indices < seq_len), ensuring the atomic update only
executes for valid positions.

Comment on lines 231 to +232
 for i, j in T.Parallel(block_N, dim):
-    if k * block_N + i < seq_len:
-        T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
+    T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])


⚠️ Potential issue | 🔴 Critical

Prevent tail-tile writes past seq_len in dQ.

This kernel also iterates a full block_N, so without the original if k * block_N + i < seq_len guard the atomic add writes beyond the valid sequence rows for trailing tiles (e.g., seq_len=130, block_N=64). Please reinstate the boundary predicate (or adjust loop bounds) before performing atomic_add.

🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_bwd_bhsd.py around lines 231-232, the
atomic_add into dQ iterates the full block_N and can write past seq_len for
trailing tiles; reinstate the original boundary predicate (e.g., check if k *
block_N + i < seq_len) or change the loop bounds so i only ranges over valid
rows before calling T.atomic_add, ensuring the guard is evaluated per iteration
and prevents writes beyond the sequence length.

LeiWang1999 previously approved these changes Oct 10, 2025
@LeiWang1999 LeiWang1999 merged commit 7cd0da9 into tile-ai:main Oct 10, 2025
6 of 10 checks passed
