[Enhancement] Enhance and add new GQA backward examples for Hopper #930
Conversation
Walkthrough

Introduces grouped gradient shapes in GQA backward and revises the accumulation/reduction logic. Adds a new WGMMA-pipelined attention implementation with forward, backward, and preprocess/postprocess kernels, an autograd wrapper, and a CPU reference plus a CLI/test harness. Several forward examples switch the V_shared buffer from [block_M, dim] to [block_N, dim]. Tests add a new case and stricter GPU capability decorators.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Py as PyTorch Caller
    participant AT as attention (autograd)
    participant FWD as flashattn_fwd (JIT)
    participant PP as flashattn_bwd_preprocess (JIT)
    participant BWD as flashattn_bwd (JIT)
    participant PO as flashattn_bwd_postprocess (JIT)
    participant Ref as ref_program (CPU)
    Py->>AT: forward(q,k,v, causal, groups)
    AT->>FWD: launch(Q,K,V, is_causal, blocks)
    FWD-->>AT: O, lse
    note right of AT: Save Q,K,V,O,lse for backward
    AT-->>Py: O
    Py->>AT: backward(dO)
    AT->>PP: preprocess(O, dO) -> Delta
    PP-->>AT: Delta
    AT->>BWD: launch(Q,K,V,dO,lse,Delta, groups)
    BWD-->>AT: dQ_partial, dK_grouped, dV_grouped
    AT->>PO: postprocess(dQ_partial)
    PO-->>AT: dQ
    note right of AT: Reduce dK/dV over groups
    AT-->>Py: dQ, sum(dK_grouped, dim=0), sum(dV_grouped, dim=0)
    Py-->>Ref: (optional) validate Q,K,V
    Ref-->>Py: O_ref, grads_ref
```

```mermaid
sequenceDiagram
    autonumber
    participant KB as Kernel (old bwd)
    participant NG as New Kernel (grouped bwd)
    rect rgb(245,245,255)
        note over KB: Old path
        KB->>KB: Atomic add to dK/dV [B,N,HKV,D]
    end
    rect rgb(245,255,245)
        note over NG: New path
        NG->>NG: Write per-group dk/dv [G,B,N,HKV,D]
        NG->>NG: Copy/accumulate within group fragments
        NG-->>NG: Host-side reduce sum over G -> dK/dV
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Pre-merge checks: ❌ Failed checks (1 warning) | ✅ Passed checks (2 passed)
cc the following comments at #917 (comment). I don't think the issue lies there. I keep both points as follows:
Actionable comments posted: 5
🧹 Nitpick comments (2)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2)
116-119: Consider renaming parameter 'l' for clarity. The layout lambda uses `l` for the sequence dimension. To improve readability, consider renaming it to `seq` or `seq_idx`. Apply this diff to improve clarity:

```diff
-    return T.Layout(dQ.shape,
-                    lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+    return T.Layout(dQ.shape,
+                    lambda b, seq, h, d: [b, seq // 8, h, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])
```
329-386: LGTM! Main function correctly validates and benchmarks. The test harness properly:
- Allocates GQA-shaped tensors.
- Runs forward and backward for both kernel and reference implementations.
- Validates outputs and gradients with appropriate tolerances.
- Prints a success message after checks pass.
- Benchmarks both implementations.
Optional: Consider renaming 'O' for clarity.

While 'O' is standard in attention literature, renaming it to `output` or `attn_output` may improve readability per linting conventions.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- `examples/flash_attention/example_gqa_bwd.py` (6 hunks)
- `examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py` (1 hunk)
- `examples/flash_attention/example_gqa_fwd_bshd.py` (1 hunk)
- `examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py` (1 hunk)
- `examples/flash_attention/example_mha_fwd_bhsd.py` (1 hunk)
- `examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py` (1 hunk)
- `examples/flash_attention/example_mha_fwd_bshd.py` (1 hunk)
- `examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py` (1 hunk)
- `examples/flash_attention/test_example_flash_attention.py` (5 hunks)
👮 Files not reviewed due to content moderation or server errors (3)
- examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
- examples/flash_attention/test_example_flash_attention.py
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
🧰 Additional context used
🧬 Code graph analysis (9)
examples/flash_attention/example_gqa_fwd_bshd.py (1)
tilelang/language/proxy.py (1)
SharedBuffer(263-264)
examples/flash_attention/example_mha_fwd_bshd.py (1)
tilelang/language/proxy.py (1)
SharedBuffer(263-264)
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)
tilelang/language/proxy.py (1)
SharedBuffer(263-264)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
tilelang/language/proxy.py (1)
SharedBuffer(263-264)
examples/flash_attention/example_gqa_bwd.py (3)
src/tl_templates/cuda/reduce.h (1)
T (75-147)
tilelang/language/allocate.py (1)
alloc_shared (21-36)
tilelang/language/copy.py (1)
copy(84-152)
examples/flash_attention/test_example_flash_attention.py (1)
tilelang/testing/__init__.py (1)
requires_cuda_compute_version_ge(95-96)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
tilelang/language/proxy.py (1)
SharedBuffer(263-264)
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
tilelang/language/proxy.py (1)
SharedBuffer(263-264)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (9)
tilelang/jit/__init__.py (2)
jit(240-313)compile(33-86)tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)tilelang/language/allocate.py (2)
alloc_shared(21-36)alloc_fragment(53-64)tilelang/language/__init__.py (1)
annotate_layout(106-144)tilelang/language/copy.py (1)
copy(84-152)tilelang/language/fill.py (2)
fill(9-21)clear(24-48)tilelang/language/pipeline.py (1)
Pipelined(9-46)tilelang/language/parallel.py (1)
Parallel(8-28)tilelang/language/reduce.py (2)
reduce_max(50-68)reduce_sum(87-109)
🪛 Ruff (0.13.2)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
95-95: Ambiguous variable name: O
(E741)
119-119: Ambiguous variable name: l
(E741)
355-355: Ambiguous variable name: O
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: bot-task
- GitHub Check: format-check
🔇 Additional comments (40)
examples/flash_attention/example_mha_fwd_bhsd.py (1)

64-64: Critical correctness fix for V_shared buffer dimensions. The shape change from `[block_M, dim]` to `[block_N, dim]` ensures dimensional consistency:
- Line 71 copies a `[block_N, dim]` slice of V into V_shared.
- Line 72 performs GEMM: `acc_s_cast [block_M, block_N] @ V_shared [block_N, dim] → acc_o [block_M, dim]`.

The old `[block_M, dim]` shape would cause failures when `block_M ≠ block_N` and was silently incorrect even when the dimensions happened to match.

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)

64-64: V_shared shape correction aligns with data flow. The buffer shape change from `[block_M, dim]` to `[block_N, dim]` correctly matches the V copy at line 71 (`V[..., k * block_N:(k + 1) * block_N, :]`) and the GEMM operation at line 72 (`[block_M, block_N] @ [block_N, dim] → [block_M, dim]`).

examples/flash_attention/example_mha_fwd_bshd.py (1)

58-58: V_shared shape correction aligns with data flow. The buffer shape change from `[block_M, dim]` to `[block_N, dim]` correctly matches the V copy at line 65 and the GEMM dimensions at line 66.

examples/flash_attention/example_gqa_fwd_bshd.py (1)

105-105: V_shared shape correction aligns with GQA data flow. The buffer shape change from `[block_M, dim]` to `[block_N, dim]` correctly matches the V copy at line 112 (with GQA indexing) and the GEMM dimensions at line 113.

examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)

72-72: V_shared shape correction aligns with pipelined GQA data flow. The buffer shape change from `[block_M, dim]` to `[block_N, dim]` correctly matches the V copy at line 79 and the GEMM dimensions at line 80.
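As a sanity check on the shape reasoning in these comments, here is a small PyTorch sketch; the block sizes are illustrative (deliberately unequal), not taken from the examples:

```python
import torch

# Illustrative block sizes; deliberately block_M != block_N.
block_M, block_N, dim = 128, 64, 96

acc_s_cast = torch.randn(block_M, block_N)   # softmax-score tile
V_shared = torch.randn(block_N, dim)         # corrected shape: [block_N, dim]

# GEMM contraction: [block_M, block_N] @ [block_N, dim] -> [block_M, dim]
acc_o = acc_s_cast @ V_shared
assert acc_o.shape == (block_M, dim)

# With the old [block_M, dim] shape, both the copy of a block_N-row slice of V
# and the GEMM contraction dimension would mismatch whenever block_M != block_N.
```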
examples/flash_attention/example_gqa_bwd.py (5)

157-158: Grouped gradient accumulation for dK and dV. Introducing per-group gradient shapes (`[groups, batch, seq_len, head_kv, dim]`) enables race-free parallel accumulation across groups, with a final sum reduction (line 286). This correctly addresses concurrent writes in GQA backward.

189-190: Shared buffer shapes corrected to match fragment dimensions. The dv_shared and dk_shared buffers are updated from `[block_N, dim]` to `[block_M, dim]` to correctly match the dv and dk fragment shapes (lines 186-187) and the K/V copy dimensions (lines 199-200). This ensures proper data flow for the gradient write-back.

235-238: Gradient write-back using grouped indexing without atomics. The copy operations correctly write dK and dV gradients to per-group locations using `bx % groups` (group index) and `bx // groups` (KV-head index). This avoids race conditions while preserving correctness, because the final sum reduction (line 286) aggregates all contributions.

279-280: Grouped gradient allocation and reduction. The backward method correctly allocates dk/dv with grouped shapes, passes them to the kernel for race-free accumulation, then reduces along the group dimension (line 286) to produce final gradients matching the original K/V shapes.

Also applies to: 282-283, 286-286

360-360: Success message improves user feedback. The added confirmation message provides clear feedback after correctness validation completes.
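The host-side pattern behind the grouped accumulation described above can be summarized with a short PyTorch sketch (shapes are illustrative; the real buffers are allocated by the example's backward method):

```python
import torch

groups, batch, seq_len, head_kv, dim = 4, 2, 256, 8, 64

# Per-group gradient buffers with a leading groups dimension, written race-free by the kernel.
dk_grouped = torch.zeros(groups, batch, seq_len, head_kv, dim)
dv_grouped = torch.zeros(groups, batch, seq_len, head_kv, dim)

# ... the kernel writes each group's contribution into dk_grouped[g] / dv_grouped[g] ...

# The final reduction over the groups dimension recovers gradients in the original K/V shape.
dK = dk_grouped.sum(dim=0)   # [batch, seq_len, head_kv, dim]
dV = dv_grouped.sum(dim=0)
assert dK.shape == (batch, seq_len, head_kv, dim)
```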
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (3)
58-58: LGTM! Buffer shape now matches copy extent. The change from `[block_M, dim]` to `[block_N, dim]` correctly aligns the V_shared buffer shape with the actual data slice copied at line 65, which is `block_N` rows.

58-58: LGTM! V_shared shape corrected. The parameter shape now matches the actual allocation at line 116 (`T.alloc_shared([block_N, dim], dtype)`) and the copy operation at line 65, which loads `block_N` rows from V.

58-58: LGTM! Correct buffer shape alignment. The V_shared shape change from `[block_M, dim]` to `[block_N, dim]` fixes a signature inconsistency. The gemm at line 66 multiplies `acc_s_cast [block_M, block_N]` by V_shared, requiring V_shared to be `[block_N, dim]` for the operation `[M, N] @ [N, dim] → [M, dim]`. This aligns with the allocation at line 116 and the copy operation at line 65 that loads `block_N` rows.
examples/flash_attention/test_example_flash_attention.py (6)

4-4: LGTM! New test correctly gated for Hopper. The new import and test function for `example_gqa_bwd_wgmma_pipelined` follow the established pattern and correctly require CUDA compute capability ≥ 9.0 (Hopper) for WGMMA support.

Also applies to: 22-25

45-45: LGTM! Hopper requirement correctly enforced. Adding `requires_cuda_compute_version_ge(9, 0)` to all WGMMA-pipelined tests ensures they only run on Hopper (sm_90+) GPUs that support the required instructions.

Also applies to: 56-56, 67-67

4-25: LGTM! New test and import for WGMMA-pipelined backward pass. The new test correctly imports and invokes the new `example_gqa_bwd_wgmma_pipelined` module, and the dual decorators properly gate execution to CUDA devices with compute capability 9.0+.

45-45: LGTM! Hopper compute requirement added. The `requires_cuda_compute_version_ge(9, 0)` decorators correctly restrict WGMMA-pipelined forward tests to hardware supporting Hopper instructions.

Also applies to: 56-56, 67-67

4-4: LGTM! New test properly gated for Hopper. The new test for `example_gqa_bwd_wgmma_pipelined` correctly uses both the `@requires_cuda` and `@requires_cuda_compute_version_ge(9, 0)` decorators, ensuring it only runs on Hopper (SM_90+) GPUs that support WGMMA instructions.

Also applies to: 22-25

45-45: LGTM! Consistent Hopper gating for WGMMA tests. Adding `@requires_cuda_compute_version_ge(9, 0)` to all `wgmma_pipelined` tests is correct. WGMMA is a Hopper-specific instruction set, and these guards prevent test failures on pre-Hopper architectures.

Also applies to: 56-56, 67-67
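For reference, a minimal sketch of how such a gated test might look, assuming the decorators are exported from `tilelang.testing` as the code-graph section above indicates; the import path and the zero-argument entry point are assumptions, not copied from the PR:

```python
import tilelang.testing

from examples.flash_attention import example_gqa_bwd_wgmma_pipelined  # assumed import path


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)  # Hopper (sm_90) or newer
def test_example_gqa_bwd_wgmma_pipelined():
    # Assumed entry point; the actual test may pass explicit shapes instead.
    example_gqa_bwd_wgmma_pipelined.main()
```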
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (21)
8-80: LGTM! Forward pass correctly implements grouped query attention. The forward kernel correctly handles GQA with `head_kv = heads // groups` and maps Q heads to K/V heads via `by // groups` indexing. Buffer shapes, access patterns, and output handling are consistent.

83-113: LGTM! Preprocess kernel correctly computes Delta. The backward preprocessing kernel correctly computes `Delta = sum(O * dO)` along the head dimension, which is required for the FlashAttention backward pass.
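In plain PyTorch terms, the quantity computed by the preprocess kernel is simply the following (shapes are illustrative; the `[batch, heads, seq_len]` layout matches the Delta tensor in the kernel signature shown later in this review):

```python
import torch

batch, seq_len, heads, dim_v = 2, 256, 8, 64
O = torch.randn(batch, seq_len, heads, dim_v)
dO = torch.randn_like(O)

# Delta[b, h, s] = sum_d O[b, s, h, d] * dO[b, s, h, d]
Delta = (O * dO).sum(dim=-1).permute(0, 2, 1).contiguous()  # [batch, heads, seq_len]
assert Delta.shape == (batch, heads, seq_len)
```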
116-119: LGTM! Layout correctly handles atomic accumulation pattern. The `make_dq_layout` function correctly reorders dQ to match the 8×8 GEMM fragment layout, which is necessary for atomic accumulation in the backward pass.

122-144: LGTM! Postprocess kernel handles layout and type conversion. The backward postprocessing kernel correctly annotates the layout and converts dQ from the accumulation dtype (float32) to the output dtype (float16).

147-251: LGTM! Backward kernel correctly implements grouped gradient accumulation. The backward kernel correctly handles grouped query attention:
- dK and dV shapes include a leading `groups` dimension (lines 157-158) for per-group accumulation, matching the PR objectives.
- K/V are read with `bx // groups` (lines 199-200), consistent with the forward pass.
- dK/dV are written with `bx % groups` (lines 247, 249) for correct group indexing.
- dQ uses atomic accumulation (line 244) because multiple blocks contribute to the same Q positions.

254-298: LGTM! Autograd wrapper correctly orchestrates forward and backward passes. The `_attention` class correctly implements the PyTorch autograd interface:
- The forward pass saves the necessary tensors and returns the output.
- The backward pass runs the preprocess → backward → postprocess pipeline, allocates grouped gradient buffers, and reduces dK/dV over the groups dimension (line 297).
- Contiguity normalization ensures efficient kernel execution.
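Structurally, the wrapper follows the pattern sketched below. This is a simplified skeleton, not the example's exact signature: the kernel callables are stand-ins passed in for illustration, while the real module uses the compiled TileLang kernels directly.

```python
import torch


class _AttentionSketch(torch.autograd.Function):
    """Simplified orchestration skeleton; the real kernels come from the compiled TileLang programs."""

    @staticmethod
    def forward(ctx, q, k, v, groups, kernels):
        o, lse = kernels["fwd"](q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.groups, ctx.kernels = groups, kernels
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        kernels = ctx.kernels
        delta = kernels["prep"](o, do.contiguous())                      # preprocess: Delta = sum(O * dO)
        dq_accum, dk_grouped, dv_grouped = kernels["bwd"](q, k, v, do, lse, delta)
        dq = kernels["post"](dq_accum)                                   # postprocess: layout + dtype conversion
        # Reduce the per-group buffers back to the original K/V shapes.
        return dq, dk_grouped.sum(dim=0), dv_grouped.sum(dim=0), None, None
```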
304-326: LGTM! Reference implementation correctly handles grouped query attention. The CPU reference correctly implements grouped query attention by repeating K/V with `repeat_interleave` and computing standard attention with optional causal masking. This provides a reliable ground truth for validation.
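A self-contained reference along these lines might look like the following sketch. It assumes the `[batch, seq, heads, dim]` layout used by the example's tensor shapes and is not the file's `ref_program` verbatim:

```python
import torch


def ref_attention(Q, K, V, causal, groups):
    # Expand K/V heads so each group of Q heads sees its shared K/V head.
    K = K.repeat_interleave(groups, dim=2)   # [B, S, H_kv, D] -> [B, S, H, D]
    V = V.repeat_interleave(groups, dim=2)
    scale = Q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", Q.float(), K.float()) * scale
    if causal:
        S = Q.shape[1]
        mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device=Q.device))
        scores = scores.masked_fill(~mask, float("-inf"))
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, V.float()).to(Q.dtype)
```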
329-387: LGTM! Main function correctly validates and benchmarks the implementation. The main function correctly:
- Sets up test inputs with appropriate shapes for grouped query attention.
- Runs both the TileLang implementation and the CPU reference.
- Validates forward and backward outputs with appropriate tolerances.
- Prints a success message after correctness checks pass (line 371), as noted in the PR objectives.
- Benchmarks both implementations for performance comparison.

1-6: LGTM! Imports are appropriate. All necessary dependencies are imported without redundancy.

8-80: LGTM! Forward pass implementation is correct. The GQA-aware forward kernel correctly indexes K/V using `by // groups`, allocates buffers with appropriate shapes, and implements the pipelined Flash Attention algorithm with proper causal masking.

83-113: LGTM! Preprocessing kernel correctly computes Delta. The preprocess step correctly computes the row-wise dot product `Delta = sum(O * dO, dim=-1)` required for the backward pass.

122-144: LGTM! Postprocessing correctly converts dQ layout and dtype. The postprocess kernel applies the custom layout to dQ and converts from float32 accumulation to float16 output.

147-245: LGTM! Backward kernel logic is correct. The GQA-aware backward kernel correctly:
- Defines grouped gradient shapes with a leading groups dimension for accumulation.
- Indexes K/V using `bx // groups` for GQA.
- Uses `bx % groups` to distribute gradients across groups.
- Implements the Flash Attention backward algorithm with proper causal masking and atomic updates for dQ.

254-298: LGTM! Autograd wrapper correctly integrates forward and backward passes. The `_attention` class correctly:
- Calls the forward kernel and saves the necessary tensors.
- Ensures contiguity for backward inputs.
- Allocates grouped gradient buffers matching the kernel's output shapes.
- Sums gradients over the groups dimension (line 297) to produce the final dK and dV.

301-301: LGTM! Standard autograd Function alias.

304-326: LGTM! Reference implementation correctly handles GQA. The CPU reference properly expands the K and V heads using `repeat_interleave` to match Q's head count, and validates shape consistency with assertions.

8-80: LGTM! Forward kernel correctly implements grouped attention. The forward kernel properly handles grouped query attention:
- Q has `heads` heads; K/V have `head_kv = heads // groups` heads (line 14).
- The kernel grid iterates over Q heads, mapping to K/V with `by // groups` (lines 51, 59).
- Buffer shapes are consistent: V_shared is `[block_N, dim_v]`, matching the corrected shape pattern from the PR.

83-113: LGTM! Backward preprocess correctly computes Delta. The preprocessing kernel computes `Delta[i] = sum_j(O[i,j] * dO[i,j])`, which is standard for FlashAttention backward. The implementation properly accumulates across blocks and reduces along dimension 1.

116-119: Layout transformation for atomic accumulation. The `make_dq_layout` function creates a swizzled layout for dQ that aligns with 8x8 gemm fragment boundaries. This is necessary because atomic operations can't be vectorized, so the layout must match the fragment structure for efficient atomic adds (line 244).

122-144: LGTM! Postprocess converts accumulated gradients to output format. The postprocess kernel converts dQ from float32 (accumulation dtype) to float16 (output dtype) while applying the layout transformation. This is the final step after atomic accumulation completes.

304-387: LGTM! Comprehensive testing with CPU reference. The reference implementation correctly handles grouped attention by repeating K/V heads (lines 315-316). The main function validates both forward and backward passes with appropriate tolerances for float16, and includes the success message mentioned in the PR objectives (line 371).
The reviewed hunk (flashattn_bwd in examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py):

```python
@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
    sm_scale = (1.0 / dim_qk)**0.5
    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim_qk]
    k_shape = [batch, seq_len, head_kv, dim_qk]
    v_shape = [batch, seq_len, head_kv, dim_v]
    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(k_shape, dtype),  # type: ignore
            V: T.Tensor(v_shape, dtype),  # type: ignore
            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(dk_shape, dtype),  # type: ignore
            dV: T.Tensor(dv_shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim_qk], dtype)
            V_shared = T.alloc_shared([block_M, dim_v], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim_v], dtype)
            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })

            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
            loop_ed = T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                T.clear(qkT)
                T.gemm(
                    K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                T.clear(dsT)
                T.gemm(
                    V_shared,
                    do,
                    dsT,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                    wg_wait=-1)
                T.wait_wgmma(1)

                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                                   0)
                T.wait_wgmma(0)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)

                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)

                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
                T.wait_wgmma(0)
                for i, j in T.Parallel(block_N, dim_qk):
                    if k * block_N + i < seq_len:
                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

    return flash_bwd
```
🧩 Analysis chain
Verify gradient copy consistency at line 249.

The backward kernel implements grouped gradient accumulation correctly by using a leading groups dimension in the dK/dV shapes (lines 157-158) and indexing with `bx % groups` (lines 247, 249). This avoids atomics for the K/V gradients.

However, line 249 copies `dk` directly, while line 247 copies `dv_shared` (after line 246 converts `dv` to `dv_shared`). Line 248 also converts `dk` to `dk_shared`. Should line 249 copy `dk_shared` instead of `dk` for consistency?

🏁 Script executed:

```bash
#!/bin/bash
# Check if there's a pattern in other backward implementations
rg -n -C 3 'T\.copy\(dk' examples/flash_attention/
```

Length of output: 4094
Correct the K-gradient copy to use dk_shared.

In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (line 249), change:

```diff
-            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+            T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
```

This matches the pattern in the other kernels and ensures the shared buffer is written out correctly.
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines 147
to 251, the final write-back for the K gradients incorrectly copies from the
fragment 'dk' instead of the shared buffer 'dk_shared'; change the last T.copy
call so you first copy the fragment into dk_shared (keep T.copy(dk, dk_shared))
and then write dk_shared into dK (replace T.copy(dk, dK[...]) with
T.copy(dk_shared, dK[...])) ensuring the same index expression used for dV is
applied to dK.
```python
            T.copy(dv, dv_shared)
            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
            T.copy(dk, dk_shared)
            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
```
Inconsistency in dK copy operation.

Line 247 correctly copies from `dv_shared`, but line 249 copies from `dk` (the fragment) instead of `dk_shared`. This is inconsistent with the dV copy pattern and may be a bug.

Apply this diff to fix the inconsistency:

```diff
 T.copy(dv, dv_shared)
 T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
 T.copy(dk, dk_shared)
-T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
```
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines 246
to 249, the dK copy is inconsistent: it currently copies from the fragment
variable `dk` instead of the shared buffer `dk_shared`, unlike the dv/dV
pattern. Change the T.copy source on the dK line to use dk_shared and keep the
same destination indexing as the dV copy (i.e., write into dK[bx % groups, bz,
by * block_M:(by + 1) * block_M, bx // groups, :]) so the shared-buffer value is
stored into the global dK tensor.
```python
    return flash_bwd


@torch.compile
```
🧩 Analysis chain
Verify the torch.compile decorator on the class.

`@torch.compile` is typically applied to functions or methods, not to class definitions, so applying it to the `_attention` class may not have the intended effect. If the goal is to compile the forward/backward methods, consider applying `@torch.compile` to individual methods, or removing it if compilation is expected to happen via the `attention = _attention.apply` call at runtime.
🌐 Web query:
Can torch.compile decorator be applied to a torch.autograd.Function class in PyTorch 2.8?
💡 Result:
Short answer: No — you should not decorate a torch.autograd.Function class with @torch.compile in PyTorch 2.8. torch.compile targets callables (plain Python functions or nn.Module.forward) and the compiled-autograd machinery rewrites/handles custom autograd.Functions at compile time; you should compile the caller (or the .apply wrapper), not the Function subclass itself. See PyTorch docs and the known limitations/bugs (e.g. tensor-subclass + save_for_backward) for details. [1][2][3]
Sources:
- torch.compile API (callable/module compilation). [1]
- Compiled Autograd / how torch.compile handles custom autograd during forward/backward. [2]
- GitHub issue noting unsupported combos with torch.autograd.Function and torch.compile. [3]
Remove the @torch.compile decorator from the custom autograd Function class (line 254).
torch.compile only supports plain functions or nn.Module.forward, not torch.autograd.Function subclasses; apply compilation to the .apply wrapper or the caller instead.
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around line 254,
the @torch.compile decorator is incorrectly applied to a custom
torch.autograd.Function subclass; remove the decorator from the class definition
and instead apply torch.compile to either the Function.apply wrapper or to the
caller (the function or nn.Module.forward that invokes .apply). Update the file
by deleting the @torch.compile line above the autograd.Function class and, if
compilation is desired, wrap the call-site or a thin wrapper function that calls
YourFunction.apply with torch.compile.
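One self-contained way to follow that advice is to compile a thin caller instead of the Function subclass. The Function body below is a trivial stand-in (not the example's `_attention`); only the placement of `torch.compile` is the point:

```python
import torch


class _Double(torch.autograd.Function):          # stand-in for the example's autograd Function
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad):
        return grad * 2


def attention_call(x):                           # compile the caller, not the Function subclass
    return _Double.apply(x)


attention = torch.compile(attention_call)

x = torch.randn(4, requires_grad=True)
attention(x).sum().backward()                    # gradient still flows through the custom Function
print(x.grad)
```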
```python
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='Batch size')
    parser.add_argument('--h', type=int, default=32, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
    parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
    parser.add_argument('--groups', type=int, default=16, help='groups')
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
```
Fix argparse boolean flag handling.

Line 396 uses `type=bool`, which doesn't work as expected with argparse: any non-empty string (including "False") is converted to True.

Apply this diff to fix the boolean flag:

```diff
-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', action='store_true', help='Causal flag')
```
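A quick self-contained demonstration of why `type=bool` misbehaves (the flag names here are only illustrative):

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument('--causal_bool', type=bool, default=False)   # broken: bool("False") == True
p.add_argument('--causal', action='store_true')             # correct: presence of the flag means True

print(p.parse_args(['--causal_bool', 'False']).causal_bool)  # True, not what the user intended
print(p.parse_args([]).causal)                               # False
print(p.parse_args(['--causal']).causal)                     # True
```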
```python
    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
    parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
    parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
```
Fix argparse boolean flag.

Using `type=bool` for the `--causal` flag is incorrect. In argparse, `type=bool` will convert any non-empty string (including "False") to True. Use `action='store_true'` instead, consistent with other examples in the codebase (e.g., example_mha_fwd_bshd_wgmma_pipelined.py line 220).

Apply this diff:

```diff
-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', action='store_true', help='Causal flag')
```
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around line 396,
the argparse declaration uses type=bool which treats any non-empty string as
True; replace that argument with a boolean flag using action='store_true'
(remove type and default), e.g. change to use action='store_true' with the same
help text so --causal sets True and absence leaves it False, matching the other
examples.
For #917 and #902.
This pull request introduces improvements and bug fixes to the FlashAttention examples, focused primarily on the backward-pass implementation for grouped query attention (GQA) and the pipelined WGMMA kernels. The most significant changes are corrected shared-memory buffer shapes, updated accumulation and reduction logic for gradients, and a new pipelined WGMMA implementation of the GQA backward pass. These updates improve correctness and performance and keep the examples consistent across the codebase.
Grouped Query Attention (GQA) Backward Pass Improvements:
- Reworked the shapes of the `dK` and `dV` gradient buffers in the backward kernel and their initialization, ensuring proper accumulation and reduction across groups. [1] [2] [3]
- Changed `dv_shared` and `dk_shared` to use `block_M` instead of `block_N`, fixing a bug in memory access and improving correctness.
- Updated the gradient write-back for `dK` and `dV` to match the new buffer shapes, improving performance.

New Pipelined WGMMA GQA Backward Example:

- Added `example_gqa_bwd_wgmma_pipelined.py`, a new example customized for Hopper that implements the GQA backward pass using pipelined WGMMA kernels. It includes forward, backward, and reference implementations, as well as benchmarking and correctness checks.

Forward Pass Buffer Shape Fixes:

- Changed the shape of `V_shared` in the macro `MMA1` from `[block_M, dim]` to `[block_N, dim]` in multiple files, ensuring consistency and correctness in the forward pass for both GQA and MHA examples. [1] [2] [3] [4]