[Feature] Add GQA backward kernel with varlen input #1082
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's required checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Adds a comprehensive flash attention implementation with forward and backward passes using TileLang, including two backward variants (atomic-add and split strategies), PyTorch autograd integration, a reference implementation for validation, and benchmarking infrastructure. Removes a cache-disable call from an existing example.
Sequence Diagram(s)
```mermaid
sequenceDiagram
participant User
participant _attention as _attention.apply
participant Preprocess
participant FwdKernel as flashattn_fwd
participant BwdPre as flashattn_bwd_preprocess
participant BwdKernel as flashattn_bwd<br/>(atomic or split)
participant BwdPost as flashattn_bwd_postprocess
participant Output
User->>_attention: q, k, v + seq_info
_attention->>Preprocess: Pad/prepare inputs
Preprocess->>FwdKernel: Forward pass
FwdKernel-->>_attention: Output, lse
_attention-->>User: Output (forward done)
note over User: Backward triggered by loss
User->>_attention: gradient (do)
_attention->>BwdPre: Compute delta
BwdPre-->>_attention: Delta
alt use_atomic=True
_attention->>BwdKernel: Atomic-add variant
else use_atomic=False
_attention->>BwdKernel: Split variant
end
BwdKernel-->>_attention: dQ, dK, dV (accum dtype)
_attention->>BwdPost: Cast to output dtype
BwdPost-->>Output: dQ, dK, dV (output dtype)
Output-->>User: Gradients
```
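For orientation, here is a minimal usage sketch of the flow above. It is not code from the PR: the padded shapes, dtypes, and sequence lengths are made up, and `attention` is assumed to be the example's Python entry point with the positional argument order shown later in this review.

```python
import torch
# Assumed import; the entry point lives in the new example file.
# from examples.flash_attention.example_gqa_bwd_tma_reduce_varlen import attention

# Hypothetical padded inputs: batch of 2 sequences, 8 query heads, 2 KV heads (groups = 4).
BATCH, N_CTX, H, HEAD_KV, DIM = 2, 128, 8, 2, 64
q = torch.randn(BATCH, N_CTX, H, DIM, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(BATCH, N_CTX, HEAD_KV, DIM, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(BATCH, N_CTX, HEAD_KV, DIM, device="cuda", dtype=torch.float16, requires_grad=True)

# Varlen bookkeeping: true per-sequence lengths plus cumulative offsets.
seqlens_q = torch.tensor([100, 128], device="cuda", dtype=torch.int32)
seqlens_k = seqlens_q.clone()
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, 0, dtype=torch.int32), (1, 0))
cu_seqlens_k = cu_seqlens_q.clone()

# Forward runs flashattn_fwd; backward picks the atomic-add or split kernel based on use_atomic.
out = attention(q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k,
                int(seqlens_q.max()), int(seqlens_k.max()),
                True, H // HEAD_KV, True)  # causal, groups, use_atomic
out.float().sum().backward()               # populates q.grad, k.grad, v.grad
```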
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~45 minutes
This includes a large new file with dense kernel logic (forward and two backward variants), tiling/masking strategies, autograd integration with padding/unpadding, a reference implementation, and validation scaffolding. While self-contained and well-structured, the multiple kernel variants and attention-specific compute patterns demand careful per-section review.
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
🧹 Nitpick comments (5)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (5)
646-653: Sum split dK/dV in float32 to avoid precision loss
The per-group partials are fp16; summing in fp16 can degrade gradients. Accumulate in fp32, then cast back.
Apply this diff:
```diff
-    dk, dv = dk.sum(0), dv.sum(0)
+    dk = dk.to(torch.float32).sum(0).to(k.dtype)
+    dv = dv.to(torch.float32).sum(0).to(v.dtype)
```
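Tangentially, a generic, self-contained illustration of the underlying numerical issue (hypothetical values, not tied to the PR's shapes or to how PyTorch reductions accumulate internally): a running sum kept in fp16 stalls once the addend falls below half an ulp of the accumulator, while accumulating in fp32 and casting once at the end preserves the result.

```python
import torch

terms = torch.full((10_000,), 0.01, dtype=torch.float16)  # 10,000 addends of ~0.01

acc16 = torch.tensor(0.0, dtype=torch.float16)
for t in terms:        # result re-rounded to fp16 after every add
    acc16 = acc16 + t  # stalls near 32, where 0.01 < half an fp16 ulp

acc32 = terms.to(torch.float32).sum().to(torch.float16)  # accumulate in fp32, cast once

print(acc16.item())  # roughly 32.0
print(acc32.item())  # roughly 100.0 (the true sum)
```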
531-539: Leverage T.copy with swizzled shared tiles for coalesced writes
You allocate dv_shared/dk_shared and annotate swizzles but then store elementwise from dv/dk. Use T.copy from the swizzled shared tiles to the global slice to unlock TMA/vectorized stores.
Apply this diff:
```diff
-            T.copy(dv, dv_shared)
-            for i, d in T.Parallel(block_M, dim_v):
-                if by * block_M + i < k_current_seqlen:
-                    dV[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dv[i, d]
-            T.copy(dk, dk_shared)
-            for i, d in T.Parallel(block_M, dim_qk):
-                if by * block_M + i < k_current_seqlen:
-                    dK[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dk[i, d]
+            T.copy(dv, dv_shared)
+            T.copy(
+                dv_shared,
+                dV[bx % groups,
+                   k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M,
+                   bx // groups, :])
+            T.copy(dk, dk_shared)
+            T.copy(
+                dk_shared,
+                dK[bx % groups,
+                   k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M,
+                   bx // groups, :])
```
221-229: Layout annotations commented out; either remove helper or re-enable where needed
make_dq_layout exists but isn't used. To reduce confusion, either drop the helper or re-enable layout annotations where TMA reduction benefits apply (and the shapes match the varlen design).
Also applies to: 298-303
156-162: Rename ambiguous variable O (E741)
Use out/output instead of O in function args and locals.
Example diffs:
```diff
-        O: T.Tensor(shape, dtype),  # type: ignore
+        Out: T.Tensor(shape, dtype),  # type: ignore

-            o[i, j] = O[q_start_idx + by * blk + i, bx, k * blk + j]
+            o[i, j] = Out[q_start_idx + by * blk + i, bx, k * blk + j]
```
```diff
-    O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
+    Out = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                    max_seqlen_k, causal, groups, use_atomic)
-    O.backward(dO, retain_graph=True)
+    Out.backward(dO, retain_graph=True)
```
Also applies to: 729-746
371-385: Atomic adds: consider staging to shared and issuing fewer, wider atomics
Hot loops issue one atomic per element. Accumulate per-warp rows into shared memory, then issue strided/wider atomics (e.g., float2/float4 if supported) to reduce contention.
I can sketch a warp-collapsed pattern if you share the target vector width constraints for T.atomic_add.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1 hunks)
- examples/flash_attention/example_gqa_fwd_varlen.py (0 hunks)
💤 Files with no reviewable changes (1)
- examples/flash_attention/example_gqa_fwd_varlen.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (3)
examples/flash_attention/bert_padding.py (2)
- pad_input (201-213)
- unpad_input (100-124)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (15)
- flashattn_fwd (13-81)
- flashattn_bwd_preprocess (88-114)
- flash_bwd_prep (95-112)
- make_dq_layout (117-119)
- flashattn_bwd_postprocess (126-157)
- flash_bwd_post (135-155)
- flashattn_bwd_atomic_add (163-262)
- flash_bwd (184-260)
- flash_bwd (291-365)
- flashattn_bwd_split (268-367)
- _attention (371-450)
- forward (374-384)
- backward (387-450)
- ref_program (456-478)
- main (481-539)
tilelang/contrib/nvcc.py (1)
- get_target_compute_version (258-299)
🪛 Ruff (0.14.0)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
158-158: Ambiguous variable name: O
(E741)
196-196: Ambiguous variable name: l
(E741)
591-591: Unpacked variable seqlens_k is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
729-729: Ambiguous variable name: O
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (1)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)
101-113: Confirm GEMM semantics; causal/non-causal mask placement may be ineffective
You initialize acc_s with {0, -inf} before T.gemm(Q_shared, K_shared, acc_s, ...). If GEMM overwrites C (rather than accumulating), the mask is lost before reduce_max. Please confirm TileLang's GEMM does C = A@B + C. If not, move the masking to after the GEMM and before the max/exp.
If it has overwrite semantics, minimally do:
```diff
-        if is_causal:
-            for i, j in T.Parallel(block_M, block_N):
-                acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                             (bx * block_M + i < q_current_seqlen and
-                                              k * block_N + j < k_current_seqlen), 0,
-                                             -T.infinity(acc_s.dtype))
-        else:
-            for i, j in T.Parallel(block_M, block_N):
-                acc_s[i, j] = T.if_then_else(
-                    bx * block_M + i < q_current_seqlen and
-                    k * block_N + j < k_current_seqlen, 0, -T.infinity(acc_s.dtype))
-        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+        T.clear(acc_s)
+        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+        for i, j in T.Parallel(block_M, block_N):
+            valid = (bx * block_M + i < q_current_seqlen) and (k * block_N + j < k_current_seqlen)
+            causal_ok = (not is_causal) or (bx * block_M + i >= k * block_N + j)
+            acc_s[i, j] = T.if_then_else(valid and causal_ok, acc_s[i, j], -T.infinity(acc_s.dtype))
```
Also applies to: 118-131
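For reference, the proposed ordering (compute the scores, then overwrite invalid entries, then take the row max/exp) corresponds to this plain-PyTorch sketch; block sizes, offsets, and lengths are made up, and it only illustrates the masking order, not the TileLang kernel itself.

```python
import torch

def masked_block_scores(q_blk, k_blk, q_offset, k_offset, q_len, k_len, causal=True):
    # Analogue of T.gemm into acc_s: compute the raw block of scores first.
    scores = q_blk @ k_blk.transpose(-1, -2)

    # Mask AFTER the matmul, so it survives whether the matmul accumulates into
    # or overwrites its output buffer.
    rows = q_offset + torch.arange(q_blk.shape[0]).unsqueeze(1)
    cols = k_offset + torch.arange(k_blk.shape[0]).unsqueeze(0)
    valid = (rows < q_len) & (cols < k_len)
    if causal:
        valid &= rows >= cols
    scores = scores.masked_fill(~valid, float("-inf"))

    # Softmax pieces in kernel order: row max, then exp.
    row_max = scores.amax(dim=-1, keepdim=True)
    return torch.exp(scores - row_max)

# Made-up 4x4 score block at offsets (0, 0), with 4 valid queries and 3 valid keys.
probs = masked_block_scores(torch.randn(4, 8), torch.randn(4, 8), 0, 0, 4, 3)
```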
```python
def make_dq_layout(dQ):
    # bshd -> bhld to use tma reduction instruction
    return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
```
Fix ambiguous lambda variable name (E741)
Rename l to a descriptive name to avoid E741 and improve readability.
Apply this diff:
```diff
 def make_dq_layout(dQ):
     # bshd -> bhld to use tma reduction instruction
-    return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
+    return T.Layout(dQ.shape, lambda b, seq, h, d: [b, h, seq, d])
```
🧰 Tools
🪛 Ruff (0.14.0)
196-196: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py around lines
194 to 197, the lambda in make_dq_layout uses an ambiguous variable name `l`
which triggers E741 and reduces readability; rename `l` to a descriptive name
such as `seq_idx` (or `length_idx`) and update the lambda signature and body to
use that new name so the layout mapping remains [b, h, seq_idx, d] instead of
[b, h, l, d].
```python
@torch.compile
class _attention(torch.autograd.Function):
```
Do not decorate an autograd.Function class with torch.compile
Decorating the class breaks .apply and can invalidate autograd. Remove the decorator; if you want compilation, wrap the callable instead (e.g., compile a lightweight wrapper, not the Function subclass).
Apply this diff:
```diff
-@torch.compile
 class _attention(torch.autograd.Function):
     ...
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
class _attention(torch.autograd.Function):
```
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py around lines
543-544, the @torch.compile decorator is applied to an autograd.Function
subclass which breaks .apply and can invalidate autograd; remove the
@torch.compile decorator from the class declaration so the Function subclass
remains undecorated, and if you need JIT/torch.compile apply it to a small
wrapper function that calls _attention.apply (e.g., define a lightweight
function that invokes _attention.apply and decorate that wrapper with
torch.compile) ensuring all callsites use the wrapper if compilation is desired.
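A runnable sketch of the wrapper approach, using a toy stand-in Function (the real _attention keeps its existing forward/backward bodies, just without the class decorator); the names here are made up.

```python
import torch


class ToyAttention(torch.autograd.Function):
    """Stand-in for the example's _attention class, left undecorated."""

    @staticmethod
    def forward(ctx, q):
        return q * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2.0


# Compile a thin callable that invokes .apply, not the Function subclass itself.
def toy_attention_fn(q):
    return ToyAttention.apply(q)


toy_attention_compiled = torch.compile(toy_attention_fn)

x = torch.randn(4, requires_grad=True)
toy_attention_compiled(x).sum().backward()  # custom backward still runs under autograd
print(x.grad)                               # tensor of 2.0s
```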
```python
q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
do_unpad, _, _, _ = unpad_input(
    do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
total_q, H, D_HEAD_QK = q.shape
total_kv, HEAD_KV, D_HEAD_V = v.shape
groups = H // HEAD_KV
BATCH = len(cu_seqlens_q) - 1

def maybe_contiguous(x):
    if x.stride(-1) != 1:
        return x.contiguous()
    return x

do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)]
```
Drop unused seqlens_k from saved_tensors (silence RUF059)
It’s saved but never used in backward. Rename to underscore at unpack or stop saving it.
Apply this diff:
```diff
-        q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        q, k, v, o, lse, seqlens_q, _seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
```
Optionally, also remove seqlens_k from ctx.save_for_backward(...).
🧰 Tools
🪛 Ruff (0.14.0)
591-591: Unpacked variable seqlens_k is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py around lines
591 to 604, seqlens_k is saved in ctx.saved_tensors but never used in the
backward code which raises a lint warning; update the unpacking to ignore it
(rename seqlens_k to _ or use an extra underscore) or remove seqlens_k from the
values passed to ctx.save_for_backward at the forward pass so it is not saved at
all, and ensure subsequent code and indices align after that change.
Is this PR ready for review? @tzj-fxz
TODO