
Conversation

Collaborator

@Rachmanino Rachmanino commented Oct 1, 2025

For #917 and #902.


This pull request introduces improvements and bug fixes to the FlashAttention examples, primarily focused on the backward pass implementation for grouped query attention (GQA) and pipelined WGMMA kernels. The most significant changes are correcting buffer shapes for shared memory usage, updating the accumulation and reduction logic for gradients, and adding a new pipelined WGMMA implementation of the GQA backward pass. These updates improve correctness and performance and keep the examples consistent across the codebase.

Grouped Query Attention (GQA) Backward Pass Improvements:

  • Corrected the shapes and initialization of the dK and dV gradient buffers in the backward kernel, ensuring proper accumulation and reduction across groups.
  • Updated the shared buffer allocation for dv_shared and dk_shared to use block_M instead of block_N, fixing a memory-access bug and improving correctness.
  • Changed the reduction logic: replaced atomic adds with explicit per-group copies followed by a sum reduction for dK and dV, matching the new buffer shapes and improving performance (a minimal sketch of this pattern follows this list).
  • Added a success message after correctness checks in the main function for better user feedback.
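
A minimal host-side sketch of the split + reduce pattern described in the third bullet above. This is not a copy of the example code: the buffer names and the kernel launch line are assumptions based on the walkthrough below; the point is the leading groups dimension and the final sum.

    # Grouped gradient buffers: one slice per group, so blocks never race on the same KV head.
    dk = torch.empty([groups, batch, seq_len, head_kv, dim], dtype=torch.float16, device="cuda")
    dv = torch.empty([groups, batch, seq_len, head_kv, dim], dtype=torch.float16, device="cuda")
    kernel(q, k, v, do, lse, delta, dq, dk, dv)  # hypothetical launch; the real call sits in the autograd backward
    # Reduce over the group dimension to recover gradients with the original K/V shapes.
    dK = dk.sum(dim=0)
    dV = dv.sum(dim=0)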

New Pipelined WGMMA GQA Backward Example:

  • Added example_gqa_bwd_wgmma_pipelined.py, a new Hopper-specific example that implements the GQA backward pass with pipelined WGMMA kernels. It includes forward, backward, and reference implementations, as well as benchmarking and correctness checks.

Forward Pass Buffer Shape Fixes:

  • Fixed the shape of the shared buffer V_shared in the macro MMA1 from [block_M, dim] to [block_N, dim] in multiple files, ensuring consistency and correctness in the forward pass for both the GQA and MHA examples (the shape reasoning is sketched below).
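
The corrected shape follows directly from the operands of the inner-loop GEMM. A minimal sketch, simplified from the forward examples (loop bounds, causal handling, and the exact GQA/MHA head indexing are elided and should not be read as the real code):

    V_shared = T.alloc_shared([block_N, dim], dtype)  # one V tile holds block_N rows
    for k in T.Pipelined(0, T.ceildiv(seq_len, block_N), num_stages=2):
        T.copy(V[..., k * block_N:(k + 1) * block_N, :], V_shared)  # copies a block_N-row slice
        T.gemm(acc_s_cast, V_shared, acc_o)  # [block_M, block_N] @ [block_N, dim] -> [block_M, dim]

Declaring V_shared as [block_M, dim] only lined up by accident when block_M == block_N; with block_M ≠ block_N the buffer matches neither the copy extent nor the GEMM operand.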

Contributor

coderabbitai bot commented Oct 1, 2025

Walkthrough

Introduces grouped gradient shapes in GQA backward and revises accumulation/reduction logic. Adds a new WGMMA-pipelined attention implementation with forward, backward, preprocess/postprocess kernels, an autograd wrapper, and a CPU reference plus CLI/test harness. Several forward examples switch V_shared buffer from [block_M, dim] to [block_N, dim]. Tests add a new case and stricter GPU capability decorators.

Changes

Cohort / File(s) Summary
GQA backward: grouped gradients
examples/flash_attention/example_gqa_bwd.py
Backward now uses grouped shapes dk_shape/dv_shape; kernel signature updated; shared-buffer dims switched for dk/dv allocations; atomic adds replaced with per-group copies and later reduction; shape_k/shape_v extended with groups; final dk/dv reduced via sum over group dimension; added status print.
New WGMMA-pipelined GQA attn (fwd+bwd)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
Adds tilelang.jit kernels: flashattn_fwd, bwd_preprocess, bwd_postprocess, flashattn_bwd; provides make_dq_layout; autograd Function with forward/backward; CPU reference ref_program; main entry and benchmarking; supports groups and causal.
FWD examples: V_shared shape adjust
examples/flash_attention/example_gqa_fwd_bshd.py, examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py, examples/flash_attention/example_mha_fwd_bhsd.py, examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py, examples/flash_attention/example_mha_fwd_bshd.py, examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
Change MMA macro parameter V_shared from [block_M, dim] to [block_N, dim]; no other control-flow changes.
Tests
examples/flash_attention/test_example_flash_attention.py
Imports new pipelined example; adds test_example_gqa_bwd_wgmma_pipelined; adds requires_cuda_compute_version_ge(9, 0) to selected tests.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Py as PyTorch Caller
  participant AT as attention (autograd)
  participant FWD as flashattn_fwd (JIT)
  participant PP as flashattn_bwd_preprocess (JIT)
  participant BWD as flashattn_bwd (JIT)
  participant PO as flashattn_bwd_postprocess (JIT)
  participant Ref as ref_program (CPU)

  Py->>AT: forward(q,k,v, causal, groups)
  AT->>FWD: launch(Q,K,V, is_causal, blocks)
  FWD-->>AT: O, lse
  note right of AT: Save Q,K,V,O,lse for backward
  AT-->>Py: O

  Py->>AT: backward(dO)
  AT->>PP: preprocess(O, dO) -> Delta
  PP-->>AT: Delta
  AT->>BWD: launch(Q,K,V,dO,lse,Delta, groups)
  BWD-->>AT: dQ_partial, dK_grouped, dV_grouped
  AT->>PO: postprocess(dQ_partial)
  PO-->>AT: dQ
  note right of AT: Reduce dK/dV over groups
  AT-->>Py: dQ, sum(dK_grouped, dim=0), sum(dV_grouped, dim=0)

  Py-->>Ref: (optional) validate Q,K,V
  Ref-->>Py: O_ref, grads_ref
sequenceDiagram
  autonumber
  participant KB as Kernel (old bwd)
  participant NG as New Kernel (grouped bwd)

  rect rgb(245,245,255)
  note over KB: Old path
  KB->>KB: Atomic add to dK/dV [B,N,HKV,D]
  end

  rect rgb(245,255,245)
  note over NG: New path
  NG->>NG: Write per-group dk/dv [G,B,N,HKV,D]
  NG->>NG: Copy/accumulate within group fragments
  NG-->>NG: Host-side reduce sum over G -> dK/dV
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

I thump my paw: new kernels bloom,
Groups align, gradients zoom.
V now dances by block_N’s tune,
WGMMA hums a pipelined rune.
Tests nibble greens of CUDA nine—
All checks passed; the carrots shine. 🥕✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check: ✅ Passed. The title succinctly summarizes the primary change (enhancing and adding grouped query attention (GQA) backward examples targeted at the Hopper architecture) without including file lists or unnecessary detail, making the main purpose of the pull request clear to reviewers.
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


github-actions bot commented Oct 1, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@LeiWang1999
Member

cc the following comments at #917 (comment).

I don’t think the issue lies with atomicAdd. This approach can also be efficient, and it’s what FA uses. While we can introduce a split + reduce template, we should not remove the original atomic template.

I'd keep both points, as follows (a rough sketch of the two write-back patterns is shown after this list):

  1. atomicAdd can still be fast.
  2. Keep the existing atomic template alongside any new split+reduce template.
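
For illustration only, a rough sketch of the two dK write-back patterns under discussion; the index expressions are simplified from the examples in this PR and neither line is a verbatim quote:

    # Atomic template: each block adds its partial dk elementwise into the shared KV-head gradient
    # (executed inside a T.Parallel loop over i, j).
    T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk[i, j])

    # Split + reduce template: each group writes its own race-free slice...
    T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
    # ...and the host then sums over the leading group dimension (dK = dk.sum(dim=0)).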

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (2)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2)

116-119: Consider renaming parameter 'l' for clarity.

The layout lambda uses l for the sequence dimension. To improve readability, consider renaming to seq or seq_idx.

Apply this diff to improve clarity:

-    return T.Layout(dQ.shape,
-                    lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+    return T.Layout(dQ.shape,
+                    lambda b, seq, h, d: [b, seq // 8, h, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])

329-386: LGTM! Main function correctly validates and benchmarks.

The test harness properly:

  • Allocates GQA-shaped tensors.
  • Runs forward and backward for both kernel and reference implementations.
  • Validates outputs and gradients with appropriate tolerances.
  • Prints a success message after checks pass.
  • Benchmarks both implementations.

Optional: Consider renaming 'O' for clarity.

While 'O' is standard in attention literature, renaming to output or attn_output may improve readability per linting conventions.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f09e91e and e56613e.

📒 Files selected for processing (9)
  • examples/flash_attention/example_gqa_bwd.py (6 hunks)
  • examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/example_gqa_fwd_bshd.py (1 hunks)
  • examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/example_mha_fwd_bhsd.py (1 hunks)
  • examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/example_mha_fwd_bshd.py (1 hunks)
  • examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/test_example_flash_attention.py (5 hunks)
👮 Files not reviewed due to content moderation or server errors (3)
  • examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
  • examples/flash_attention/test_example_flash_attention.py
  • examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
🧰 Additional context used
🧬 Code graph analysis (9)
examples/flash_attention/example_gqa_fwd_bshd.py (1)
tilelang/language/proxy.py (1)
  • SharedBuffer (263-264)
examples/flash_attention/example_mha_fwd_bshd.py (1)
tilelang/language/proxy.py (1)
  • SharedBuffer (263-264)
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)
tilelang/language/proxy.py (1)
  • SharedBuffer (263-264)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
tilelang/language/proxy.py (1)
  • SharedBuffer (263-264)
examples/flash_attention/example_gqa_bwd.py (3)
src/tl_templates/cuda/reduce.h (1)
  • T (75-147)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/copy.py (1)
  • copy (84-152)
examples/flash_attention/test_example_flash_attention.py (1)
tilelang/testing/__init__.py (1)
  • requires_cuda_compute_version_ge (95-96)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
tilelang/language/proxy.py (1)
  • SharedBuffer (263-264)
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
tilelang/language/proxy.py (1)
  • SharedBuffer (263-264)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (9)
tilelang/jit/__init__.py (2)
  • jit (240-313)
  • compile (33-86)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/__init__.py (1)
  • annotate_layout (106-144)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/fill.py (2)
  • fill (9-21)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
tilelang/language/reduce.py (2)
  • reduce_max (50-68)
  • reduce_sum (87-109)
🪛 Ruff (0.13.2)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py

95-95: Ambiguous variable name: O

(E741)


119-119: Ambiguous variable name: l

(E741)


355-355: Ambiguous variable name: O

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: bot-task
  • GitHub Check: format-check
🔇 Additional comments (40)
examples/flash_attention/example_mha_fwd_bhsd.py (1)

64-64: Critical correctness fix for V_shared buffer dimensions.

The shape change from [block_M, dim] to [block_N, dim] ensures dimensional consistency:

  • Line 71 copies a [block_N, dim] slice of V into V_shared
  • Line 72 performs GEMM: acc_s_cast [block_M, block_N] @ V_shared [block_N, dim] → acc_o [block_M, dim]

The old [block_M, dim] shape would cause failures whenever block_M ≠ block_N, and even when the two happened to match it misstated the buffer's actual extent.

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)

64-64: V_shared shape correction aligns with data flow.

The buffer shape change from [block_M, dim] to [block_N, dim] correctly matches the V copy at line 71 (V[..., k * block_N:(k + 1) * block_N, :]) and the GEMM operation at line 72 ([block_M, block_N] @ [block_N, dim] → [block_M, dim]).

examples/flash_attention/example_mha_fwd_bshd.py (1)

58-58: V_shared shape correction aligns with data flow.

The buffer shape change from [block_M, dim] to [block_N, dim] correctly matches the V copy at line 65 and the GEMM dimensions at line 66.

examples/flash_attention/example_gqa_fwd_bshd.py (1)

105-105: V_shared shape correction aligns with GQA data flow.

The buffer shape change from [block_M, dim] to [block_N, dim] correctly matches the V copy at line 112 (with GQA indexing) and the GEMM dimensions at line 113.

examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)

72-72: V_shared shape correction aligns with pipelined GQA data flow.

The buffer shape change from [block_M, dim] to [block_N, dim] correctly matches the V copy at line 79 and the GEMM dimensions at line 80.

examples/flash_attention/example_gqa_bwd.py (5)

157-158: Grouped gradient accumulation for dK and dV.

Introducing per-group gradient shapes ([groups, batch, seq_len, head_kv, dim]) enables race-free parallel accumulation across groups, with a final sum reduction (line 286). This correctly addresses concurrent writes in GQA backward.


189-190: Shared buffer shapes corrected to match fragment dimensions.

The dv_shared and dk_shared buffers are updated from [block_N, dim] to [block_M, dim] to correctly match the dv and dk fragment shapes (lines 186-187) and the K/V copy dimensions (lines 199-200). This ensures proper data flow for the gradient write-back.


235-238: Gradient write-back using grouped indexing without atomics.

The copy operations correctly write dK and dV gradients to per-group locations using bx % groups (group index) and bx // groups (KV head index). This avoids race conditions while preserving correctness, as the final sum reduction (line 286) aggregates all contributions.


279-280: Grouped gradient allocation and reduction.

The backward method correctly allocates dk/dv with grouped shapes, passes them to the kernel for race-free accumulation, then reduces along the group dimension (line 286) to produce final gradients matching the original K/V shapes.

Also applies to: 282-283, 286-286


360-360: Success message improves user feedback.

The added confirmation message provides clear feedback after correctness validation completes.

examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (3)

58-58: LGTM! Buffer shape now matches copy extent.

The change from [block_M, dim] to [block_N, dim] correctly aligns the V_shared buffer shape with the actual data slice copied at line 65, which is block_N rows.


58-58: LGTM! V_shared shape corrected.

The parameter shape now matches the actual allocation at line 116 (T.alloc_shared([block_N, dim], dtype)) and the copy operation at line 65, which loads block_N rows from V.


58-58: LGTM! Correct buffer shape alignment.

The V_shared shape change from [block_M, dim] to [block_N, dim] fixes a signature inconsistency. The gemm at line 66 multiplies acc_s_cast[block_M, block_N] by V_shared, requiring V_shared to be [block_N, dim] for the operation [M, N] @ [N, dim] → [M, dim]. This aligns with the allocation at line 116 and the copy operation at line 65 that loads block_N rows.

examples/flash_attention/test_example_flash_attention.py (6)

4-4: LGTM! New test correctly gated for Hopper.

The new import and test function for example_gqa_bwd_wgmma_pipelined follow the established pattern and correctly require CUDA compute capability ≥ 9.0 (Hopper) for WGMMA support.

Also applies to: 22-25


45-45: LGTM! Hopper requirement correctly enforced.

Adding requires_cuda_compute_version_ge(9, 0) to all WGMMA-pipelined tests ensures they only run on Hopper (sm_90+) GPUs that support the required instructions.

Also applies to: 56-56, 67-67


4-25: LGTM! New test and import for WGMMA pipelined backward pass.

The new test correctly imports and invokes the new example_gqa_bwd_wgmma_pipelined module, and the dual decorators properly gate execution to CUDA devices with compute capability 9.0+.


45-45: LGTM! Hopper compute requirement added.

The requires_cuda_compute_version_ge(9, 0) decorators correctly restrict WGMMA-pipelined forward tests to hardware supporting Hopper instructions.

Also applies to: 56-56, 67-67


4-4: LGTM! New test properly gated for Hopper.

The new test for example_gqa_bwd_wgmma_pipelined correctly uses both @requires_cuda and @requires_cuda_compute_version_ge(9, 0) decorators, ensuring it only runs on Hopper (SM_90+) GPUs that support WGMMA instructions.

Also applies to: 22-25


45-45: LGTM! Consistent Hopper gating for WGMMA tests.

Adding @requires_cuda_compute_version_ge(9, 0) to all wgmma_pipelined tests is correct. WGMMA is a Hopper-specific instruction set, and these guards prevent test failures on pre-Hopper architectures.

Also applies to: 56-56, 67-67

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (21)

8-80: LGTM! Forward pass correctly implements grouped query attention.

The forward kernel correctly handles GQA with head_kv = heads // groups and maps Q heads to K/V heads via by // groups indexing. Buffer shapes, access patterns, and output handling are consistent.


83-113: LGTM! Preprocess kernel correctly computes Delta.

The backward preprocessing kernel correctly computes Delta = sum(O * dO) along the head dimension, which is required for the FlashAttention backward pass.


116-119: LGTM! Layout correctly handles atomic accumulation pattern.

The make_dq_layout function correctly reorders dQ to match 8×8 GEMM fragment layout, which is necessary for atomic accumulation in the backward pass.


122-144: LGTM! Postprocess kernel handles layout and type conversion.

The backward postprocessing kernel correctly annotates the layout and converts dQ from the accumulation dtype (float32) to output dtype (float16).


147-251: LGTM! Backward kernel correctly implements grouped gradient accumulation.

The backward kernel correctly handles grouped query attention:

  • dK and dV shapes include a leading groups dimension (lines 157-158) for per-group accumulation, matching the PR objectives.
  • K/V are read with bx // groups (lines 199-200), consistent with the forward pass.
  • dK/dV are written with bx % groups (lines 247, 249) for correct group indexing.
  • dQ uses atomic accumulation (line 244) as multiple blocks contribute to the same Q positions.

254-298: LGTM! Autograd wrapper correctly orchestrates forward and backward passes.

The _attention class correctly implements the PyTorch autograd interface:

  • Forward pass saves necessary tensors and returns output.
  • Backward pass runs preprocess → backward → postprocess pipeline, allocates grouped gradient buffers, and reduces dK/dV over the groups dimension (line 297).
  • Contiguity normalization ensures efficient kernel execution.

304-326: LGTM! Reference implementation correctly handles grouped query attention.

The CPU reference correctly implements grouped query attention by repeating K/V with repeat_interleave and computing standard attention with optional causal masking. This provides a reliable ground truth for validation.
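
A minimal sketch of what such a reference looks like. This is an assumption based on the description above (repeat_interleave over the head dimension plus standard softmax attention), not the actual ref_program:

    import torch

    def ref_attention_sketch(q, k, v, causal, groups):
        # q: [B, S, H, Dqk]; k: [B, S, H_kv, Dqk]; v: [B, S, H_kv, Dv], with H = H_kv * groups
        k_rep = k.repeat_interleave(groups, dim=2)
        v_rep = v.repeat_interleave(groups, dim=2)
        scores = torch.einsum('bqhd,bkhd->bhqk', q.float(), k_rep.float()) / q.shape[-1]**0.5
        if causal:
            s = q.shape[1]
            mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), diagonal=1)
            scores = scores.masked_fill(mask, float('-inf'))
        return torch.einsum('bhqk,bkhd->bqhd', scores.softmax(dim=-1), v_rep.float()).to(q.dtype)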


329-387: LGTM! Main function correctly validates and benchmarks the implementation.

The main function correctly:

  • Sets up test inputs with appropriate shapes for grouped query attention.
  • Runs both the TileLang implementation and CPU reference.
  • Validates forward and backward outputs with appropriate tolerances.
  • Prints a success message after correctness checks pass (line 371), as noted in the PR objectives.
  • Benchmarks both implementations for performance comparison.

1-6: LGTM! Imports are appropriate.

All necessary dependencies are imported without redundancy.


8-80: LGTM! Forward pass implementation is correct.

The GQA-aware forward kernel correctly indexes K/V using by // groups, allocates buffers with appropriate shapes, and implements the pipelined Flash Attention algorithm with proper causal masking.


83-113: LGTM! Preprocessing kernel correctly computes Delta.

The preprocess step correctly computes the row-wise dot product Delta = sum(O * dO, dim=-1) required for the backward pass.


122-144: LGTM! Postprocessing correctly converts dQ layout and dtype.

The postprocess kernel applies the custom layout to dQ and converts from float32 accumulation to float16 output.


147-245: LGTM! Backward kernel logic is correct.

The GQA-aware backward kernel correctly:

  • Defines grouped gradient shapes with a leading groups dimension for accumulation.
  • Indexes K/V using bx // groups for GQA.
  • Uses bx % groups to distribute gradients across groups.
  • Implements the Flash Attention backward algorithm with proper causal masking and atomic updates for dQ.

254-298: LGTM! Autograd wrapper correctly integrates forward and backward passes.

The _attention class correctly:

  • Calls the forward kernel and saves necessary tensors.
  • Ensures contiguity for backward inputs.
  • Allocates grouped gradient buffers matching the kernel's output shapes.
  • Sums gradients over the groups dimension (line 297) to produce final dK and dV.

301-301: LGTM! Standard autograd Function alias.


304-326: LGTM! Reference implementation correctly handles GQA.

The CPU reference properly expands K and V heads using repeat_interleave to match Q's head count, and validates shape consistency with assertions.


8-80: LGTM! Forward kernel correctly implements grouped attention.

The forward kernel properly handles grouped query attention:

  • Q has heads heads, K/V have head_kv = heads // groups heads (line 14).
  • Kernel grid iterates over Q heads, mapping to K/V with by // groups (lines 51, 59).
  • Buffer shapes are consistent: V_shared is [block_N, dim_v] matching the corrected shape pattern from the PR.

83-113: LGTM! Backward preprocess correctly computes Delta.

The preprocessing kernel computes Delta[i] = sum_j(O[i,j] * dO[i,j]), which is standard for FlashAttention backward. The implementation properly accumulates across blocks and reduces along dimension 1.


116-119: Layout transformation for atomic accumulation.

The make_dq_layout function creates a swizzled layout for dQ that aligns with 8x8 gemm fragment boundaries. This is necessary because atomic operations can't be vectorized, so the layout must match the fragment structure for efficient atomic adds (line 244).


122-144: LGTM! Postprocess converts accumulated gradients to output format.

The postprocess kernel converts dQ from float32 (accumulation dtype) to float16 (output dtype) while applying the layout transformation. This is the final step after atomic accumulation completes.


304-387: LGTM! Comprehensive testing with CPU reference.

The reference implementation correctly handles grouped attention by repeating K/V heads (lines 315-316). The main function validates both forward and backward passes with appropriate tolerances for float16, and includes the success message mentioned in the PR objectives (line 371).

Comment on lines +147 to +251
@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(dk_shape, dtype),  # type: ignore
            dV: T.Tensor(dv_shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })

            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(
                    K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(
                    V_shared,
                    do,
                    dsT,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                    wg_wait=-1)
                T.wait_wgmma(1)

                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
                T.wait_wgmma(0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)

                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
                T.wait_wgmma(0)
                for i, j in T.Parallel(block_N, dim_qk):
                    if k * block_N + i < seq_len:
                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

    return flash_bwd

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify gradient copy consistency at line 249.

The backward kernel implements grouped gradient accumulation correctly by using a leading groups dimension in dK/dV shapes (lines 157-158) and indexing with bx % groups (lines 247, 249). This avoids atomics for K/V gradients.

However, line 249 copies dk directly while line 247 copies dv_shared (after line 246 converts dv to dv_shared). Line 248 also converts dk to dk_shared. Should line 249 copy dk_shared instead of dk for consistency?


🏁 Script executed:

#!/bin/bash
# Check if there's a pattern in other backward implementations
rg -n -C 3 'T\.copy\(dk' examples/flash_attention/

Length of output: 4094


Correct the K-gradient copy to use dk_shared.

In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (line 249), change:

-            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+            T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

This matches the pattern in other kernels and ensures the shared buffer is written out correctly.

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines 147
to 251, the final write-back for the K gradients incorrectly copies from the
fragment 'dk' instead of the shared buffer 'dk_shared'; change the last T.copy
call so you first copy the fragment into dk_shared (keep T.copy(dk, dk_shared))
and then write dk_shared into dK (replace T.copy(dk, dK[...]) with
T.copy(dk_shared, dK[...])) ensuring the same index expression used for dV is
applied to dK.

Comment on lines +246 to +249
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared)
T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

⚠️ Potential issue | 🟠 Major

Inconsistency in dK copy operation.

Line 247 correctly copies from dv_shared, but line 249 copies from dk (fragment) instead of dk_shared. This is inconsistent with the dV copy pattern and may be a bug.

Apply this diff to fix the inconsistency:

 T.copy(dv, dv_shared)
 T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
 T.copy(dk, dk_shared)
-T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines 246
to 249, the dK copy is inconsistent: it currently copies from the fragment
variable `dk` instead of the shared buffer `dk_shared`, unlike the dv/dV
pattern. Change the T.copy source on the dK line to use dk_shared and keep the
same destination indexing as the dV copy (i.e., write into dK[bx % groups, bz,
by * block_M:(by + 1) * block_M, bx // groups, :]) so the shared-buffer value is
stored into the global dK tensor.

return flash_bwd


@torch.compile

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify torch.compile decorator on class.

@torch.compile is typically applied to functions or methods, not class definitions. Applying it to the _attention class may not have the intended effect. If the goal is to compile the forward/backward methods, consider applying @torch.compile to individual methods or removing it if compilation happens via the attention = _attention.apply call at runtime.


🌐 Web query:

Can torch.compile decorator be applied to a torch.autograd.Function class in PyTorch 2.8?

💡 Result:

Short answer: No — you should not decorate a torch.autograd.Function class with @torch.compile in PyTorch 2.8. torch.compile targets callables (plain Python functions or nn.Module.forward) and the compiled-autograd machinery rewrites/handles custom autograd.Functions at compile time; you should compile the caller (or the .apply wrapper), not the Function subclass itself. See PyTorch docs and the known limitations/bugs (e.g. tensor-subclass + save_for_backward) for details. [1][2][3]

Sources:

  • torch.compile API (callable/module compilation). [1]
  • Compiled Autograd / how torch.compile handles custom autograd during forward/backward. [2]
  • GitHub issue noting unsupported combos with torch.autograd.Function and torch.compile. [3]

Remove the @torch.compile decorator from the custom autograd Function class (line 254).

torch.compile only supports plain functions or nn.Module.forward, not torch.autograd.Function subclasses; apply compilation to the .apply wrapper or the caller instead.

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around line 254,
the @torch.compile decorator is incorrectly applied to a custom
torch.autograd.Function subclass; remove the decorator from the class definition
and instead apply torch.compile to either the Function.apply wrapper or to the
caller (the function or nn.Module.forward that invokes .apply). Update the file
by deleting the @torch.compile line above the autograd.Function class and, if
compilation is desired, wrap the call-site or a thin wrapper function that calls
YourFunction.apply with torch.compile.
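
If compilation is still desired, a minimal sketch of moving it to a thin wrapper around .apply; the wrapper name is hypothetical and the argument list is assumed from the sequence diagram earlier on this page:

    @torch.compile
    def attention_compiled(q, k, v, causal=False, groups=1):
        # Compile the caller; compiled autograd then handles the custom Function's forward/backward.
        return _attention.apply(q, k, v, causal, groups)

    o = attention_compiled(q, k, v, causal=True, groups=16)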

Comment on lines +389 to +399
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
    parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)

⚠️ Potential issue | 🟠 Major

Fix argparse boolean flag handling.

Line 396 uses type=bool, which doesn't work as expected with argparse. Any non-empty string (including "False") will be converted to True.

Apply this diff to fix the boolean flag:

-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', action='store_true', help='Causal flag')
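
A quick standalone demonstration of the pitfall, independent of the example script:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--causal', type=bool, default=False)
    print(p.parse_args(['--causal', 'False']).causal)  # True: bool('False') is truthy
    print(p.parse_args([]).causal)                     # False only because the default applies

With action='store_true', passing --causal sets the flag and omitting it leaves it False.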

parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')

⚠️ Potential issue | 🟡 Minor

Fix argparse boolean flag.

Using type=bool for the --causal flag is incorrect. In argparse, type=bool will convert any non-empty string (including "False") to True. Use action='store_true' instead, consistent with other examples in the codebase (e.g., example_mha_fwd_bshd_wgmma_pipelined.py line 220).

Apply this diff:

-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', action='store_true', help='Causal flag')
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around line 396,
the argparse declaration uses type=bool which treats any non-empty string as
True; replace that argument with a boolean flag using action='store_true'
(remove type and default), e.g. change to use action='store_true' with the same
help text so --causal sets True and absence leaves it False, matching the other
examples.

