
Conversation

@tzj-fxz
Contributor

@tzj-fxz tzj-fxz commented Oct 20, 2025

  • Add GQA backward kernel with varlen input
  • Fix a bug in the original GQA backward kernel with split dK&dV

TODO

  • Check the correctness of the split version
  • Check the correctness of the atomic-add version
  • Vectorize the atomic add to optimize the kernel

Summary by CodeRabbit

  • New Features

    • Added flash attention backward implementation example supporting variable-length sequences
    • Included PyTorch autograd integration for efficient gradient computation
    • Added performance profiling and correctness validation capabilities
  • Chores

    • Updated flash attention forward example with minor optimizations

@tzj-fxz tzj-fxz requested a review from chengyupku October 20, 2025 14:23
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help your contribution pass the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Oct 20, 2025

Walkthrough

Adds a comprehensive flash attention implementation with forward and backward passes using TileLang, including two backward variants (atomic-add and split strategies), PyTorch autograd integration, reference implementation for validation, and benchmarking infrastructure. Removes a cache-disable call from an existing example.

Changes

Cohort / File(s) and Summary:

  • Flash Attention Implementation: examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
    New file: complete CUDA-tuned flash attention pipeline with flashattn_fwd (forward kernel with tiled GEMM and causal masking), flashattn_bwd_preprocess/postprocess (pre/post-processing stages), flashattn_bwd_atomic_add and flashattn_bwd_split (two backward variants with per-block tiling), _attention (PyTorch autograd.Function wrapper), ref_program (reference PyTorch implementation for correctness checks), and a main CLI/benchmark entry point with validation and profiling.
  • Example Cleanup: examples/flash_attention/example_gqa_fwd_varlen.py
    Removed the tilelang.disable_cache() call.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant _attention as _attention.apply
    participant Preprocess
    participant FwdKernel as flashattn_fwd
    participant BwdPre as flashattn_bwd_preprocess
    participant BwdKernel as flashattn_bwd<br/>(atomic or split)
    participant BwdPost as flashattn_bwd_postprocess
    participant Output

    User->>_attention: q, k, v + seq_info
    _attention->>Preprocess: Pad/prepare inputs
    Preprocess->>FwdKernel: Forward pass
    FwdKernel-->>_attention: Output, lse
    _attention-->>User: Output (forward done)
    
    note over User: Backward triggered by loss
    User->>_attention: gradient (do)
    _attention->>BwdPre: Compute delta
    BwdPre-->>_attention: Delta
    alt use_atomic=True
        _attention->>BwdKernel: Atomic-add variant
    else use_atomic=False
        _attention->>BwdKernel: Split variant
    end
    BwdKernel-->>_attention: dQ, dK, dV (accum dtype)
    _attention->>BwdPost: Cast to output dtype
    BwdPost-->>Output: dQ, dK, dV (output dtype)
    Output-->>User: Gradients
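For orientation, a rough end-to-end usage sketch matching the diagram above, assuming PyTorch 2.x, a CUDA device, and that the example file's _attention wrapper is importable; the shapes, dtypes, and varlen metadata below are illustrative placeholders, not values taken from the PR.

import torch
from example_gqa_bwd_tma_reduce_varlen import _attention  # hypothetical import path

B, N_CTX, H, HEAD_KV, D = 2, 512, 8, 2, 64   # 8 query heads grouped over 2 KV heads
groups = H // HEAD_KV
seqlens_q = torch.tensor([384, 512], dtype=torch.int32, device="cuda")
seqlens_k = seqlens_q.clone()
cu_seqlens_q = torch.nn.functional.pad(seqlens_q.cumsum(0, dtype=torch.int32), (1, 0))
cu_seqlens_k = cu_seqlens_q.clone()
max_seqlen_q = max_seqlen_k = N_CTX

q = torch.randn(B, N_CTX, H, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B, N_CTX, HEAD_KV, D, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B, N_CTX, HEAD_KV, D, device="cuda", dtype=torch.float16, requires_grad=True)

# Forward: flashattn_fwd produces the attention output and the per-row log-sum-exp.
out = _attention.apply(q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k,
                       max_seqlen_q, max_seqlen_k, True, groups, True)  # causal=True, use_atomic=True

# Backward: the preprocess kernel computes delta, the atomic-add or split kernel
# accumulates dQ/dK/dV, and the postprocess kernel casts them to the output dtype.
out.backward(torch.randn_like(out))
print(q.grad.shape, k.grad.shape, v.grad.shape)

The positional argument order follows the call used later in the example's benchmark path: attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic).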

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

This includes a large new file with dense kernel logic (forward and two backward variants), tiling/masking strategies, autograd integration with padding/unpadding, reference implementation, and validation scaffolding. While self-contained and well-structured, the multiple kernel variants and attention-specific compute patterns demand careful per-section review.

Possibly related PRs

  • tile-ai/tilelang#943: Consolidates dual backward kernels (atomic/split) into a single variant, directly modifying the same backward-kernel selection and _attention API surface.
  • tile-ai/tilelang#1065: Simplifies flashattn_bwd variants and related helpers, competing with this PR's expansion of atomic/split backward paths.
  • tile-ai/tilelang#1010: Implements TileLang attention backward kernels with atomic-add accumulation and shared-buffer strategies, sharing core kernel design patterns.

Suggested reviewers

  • LeiWang1999

Poem

🐰 Kernels bloom in tiles so bright,
Forward, backward—day and night,
Atomic spins and splits aligned,
Attention woven, fused, refined!
TileLang magic, GPU's delight!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The pull request title "[Feature] Add GQA backward kernel with varlen input" directly and accurately describes the primary change in the changeset. The main contribution is a comprehensive GQA backward kernel implementation with variable-length input support in the new file example_gqa_bwd_tma_reduce_varlen.py, including multiple backward variants (atomic add and split), the forward pass, and autograd integration. The title is concise, specific, and uses clear terminology that would let a teammate reviewing history immediately understand what was added. The secondary change (removing a disable_cache call in another file) is minor and doesn't warrant mention in the title.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (5)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (5)

646-653: Sum split dK/dV in float32 to avoid precision loss

The per-group partials are fp16; summing in fp16 can degrade gradients. Accumulate in fp32 then cast.

Apply this diff:

-            dk, dv = dk.sum(0), dv.sum(0)
+            dk = dk.to(torch.float32).sum(0).to(k.dtype)
+            dv = dv.to(torch.float32).sum(0).to(v.dtype)

531-539: Leverage T.copy with swizzled shared tiles for coalesced writes

You allocate dv_shared/dk_shared and annotate swizzles but then store elementwise from dv/dk. Use T.copy from the swizzled shared tiles to the global slice to unlock TMA/vectorized stores.

Apply this diff:

-            T.copy(dv, dv_shared)
-            for i, d in T.Parallel(block_M, dim_v):
-                if by * block_M + i < k_current_seqlen:
-                    dV[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dv[i, d]
-            T.copy(dk, dk_shared)
-            for i, d in T.Parallel(block_M, dim_qk):
-                if by * block_M + i < k_current_seqlen:
-                    dK[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dk[i, d]
+            T.copy(dv, dv_shared)
+            T.copy(
+                dv_shared,
+                dV[bx % groups,
+                   k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M,
+                   bx // groups, :])
+            T.copy(dk, dk_shared)
+            T.copy(
+                dk_shared,
+                dK[bx % groups,
+                   k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M,
+                   bx // groups, :])

221-229: Layout annotations commented out; either remove helper or re-enable where needed

make_dq_layout exists but isn’t used. To reduce confusion, either drop the helper or re-enable layout annotations where TMA reduction benefits apply (and the shapes match the varlen design).

Also applies to: 298-303


156-162: Rename ambiguous variable O (E741)

Use out/output instead of O in function args and locals.

Example diffs:

-            O: T.Tensor(shape, dtype),  # type: ignore
+            Out: T.Tensor(shape, dtype),  # type: ignore
-                        o[i, j] = O[q_start_idx + by * blk + i, bx, k * blk + j]
+                        o[i, j] = Out[q_start_idx + by * blk + i, bx, k * blk + j]
-    O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
+    Out = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                   max_seqlen_k, causal, groups, use_atomic)
-    O.backward(dO, retain_graph=True)
+    Out.backward(dO, retain_graph=True)

Also applies to: 729-746


371-385: Atomic adds: consider staging to shared and issuing fewer wider atomics

Hot loops issue 1 atomic per element. Accumulate per-warp rows into shared, then issue strided/wider atomics (e.g., float2/float4 if supported) to reduce contention.

I can sketch a warp-collapsed pattern if you share target vector width constraints for T.atomic_add.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bc37ea6 and 7f5d59a.

📒 Files selected for processing (2)
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1 hunks)
  • examples/flash_attention/example_gqa_fwd_varlen.py (0 hunks)
💤 Files with no reviewable changes (1)
  • examples/flash_attention/example_gqa_fwd_varlen.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (3)
examples/flash_attention/bert_padding.py (2)
  • pad_input (201-213)
  • unpad_input (100-124)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (15)
  • flashattn_fwd (13-81)
  • flashattn_bwd_preprocess (88-114)
  • flash_bwd_prep (95-112)
  • make_dq_layout (117-119)
  • flashattn_bwd_postprocess (126-157)
  • flash_bwd_post (135-155)
  • flashattn_bwd_atomic_add (163-262)
  • flash_bwd (184-260)
  • flash_bwd (291-365)
  • flashattn_bwd_split (268-367)
  • _attention (371-450)
  • forward (374-384)
  • backward (387-450)
  • ref_program (456-478)
  • main (481-539)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (258-299)
🪛 Ruff (0.14.0)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py

158-158: Ambiguous variable name: O

(E741)


196-196: Ambiguous variable name: l

(E741)


591-591: Unpacked variable seqlens_k is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


729-729: Ambiguous variable name: O

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (1)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)

101-113: Confirm GEMM semantics; causal/non-causal mask placement may be ineffective

You initialize acc_s with {0,-inf} before T.gemm(Q_shared, K_shared, acc_s, ...). If GEMM overwrites C (not accumulates), the mask is lost before reduce_max. Please confirm TileLang’s GEMM does C = A@B + C. If not, move masking to after GEMM and before the max/exp.

If it has overwrite semantics, minimally do:

-                if is_causal:
-                    for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i < q_current_seqlen and
-                                                      k * block_N + j < k_current_seqlen), 0,
-                                                     -T.infinity(acc_s.dtype))
-                else:
-                    for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else(
-                            bx * block_M + i < q_current_seqlen and
-                            k * block_N + j < k_current_seqlen, 0, -T.infinity(acc_s.dtype))
-                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.clear(acc_s)
+                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                for i, j in T.Parallel(block_M, block_N):
+                    valid = (bx * block_M + i < q_current_seqlen) and (k * block_N + j < k_current_seqlen)
+                    causal_ok = (not is_causal) or (bx * block_M + i >= k * block_N + j)
+                    acc_s[i, j] = T.if_then_else(valid and causal_ok, acc_s[i, j], -T.infinity(acc_s.dtype))

Also applies to: 118-131
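To make the ordering concern concrete, here is a small plain-PyTorch illustration (not TileLang, and not code from the PR) of why a masked accumulator initialization only survives accumulate-style GEMM semantics (C = A@B + C), while masking after the GEMM is safe under either convention:

import torch

q, k = torch.randn(4, 8), torch.randn(4, 8)
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # causal validity per (row, col)
neg_inf = torch.tensor(float("-inf"))

# Masked init + accumulate: the -inf entries survive the GEMM.
acc_init = torch.where(mask, torch.tensor(0.0), neg_inf)
accumulated = acc_init + q @ k.T           # models C = A@B + C

# Overwrite semantics discard the masked init entirely.
overwritten = q @ k.T                      # models C = A@B

# Masking applied after the GEMM is correct under either semantics.
masked_after = torch.where(mask, q @ k.T, neg_inf)

print(torch.equal(accumulated, masked_after))  # True
print(torch.equal(overwritten, masked_after))  # False: upper-triangle entries stay unmasked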

Comment on lines +194 to +197
def make_dq_layout(dQ):
    # bshd -> bhld to use tma reduction instruction
    return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])


⚠️ Potential issue | 🟡 Minor

Fix ambiguous lambda variable name (E741)

Rename l to a descriptive name to avoid E741 and improve readability.

Apply this diff:

-def make_dq_layout(dQ):
+def make_dq_layout(dQ):
     # bshd -> bhld to use tma reduction instruction
-    return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
+    return T.Layout(dQ.shape, lambda b, seq, h, d: [b, h, seq, d])
🧰 Tools
🪛 Ruff (0.14.0)

196-196: Ambiguous variable name: l

(E741)

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py around lines
194 to 197, the lambda in make_dq_layout uses an ambiguous variable name `l`
which triggers E741 and reduces readability; rename `l` to a descriptive name
such as `seq_idx` (or `length_idx`) and update the lambda signature and body to
use that new name so the layout mapping remains [b, h, seq_idx, d] instead of
[b, h, l, d].

Comment on lines +543 to +544
@torch.compile
class _attention(torch.autograd.Function):

⚠️ Potential issue | 🔴 Critical

Do not decorate an autograd.Function class with torch.compile

Decorating the class breaks .apply and can invalidate autograd. Remove the decorator; if you want compilation, wrap the callable instead (e.g., compile a lightweight wrapper, not the Function subclass).

Apply this diff:

-@torch.compile
 class _attention(torch.autograd.Function):
     ...
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py around lines
543-544, the @torch.compile decorator is applied to an autograd.Function
subclass which breaks .apply and can invalidate autograd; remove the
@torch.compile decorator from the class declaration so the Function subclass
remains undecorated, and if you need JIT/torch.compile apply it to a small
wrapper function that calls _attention.apply (e.g., define a lightweight
function that invokes _attention.apply and decorate that wrapper with
torch.compile) ensuring all callsites use the wrapper if compilation is desired.
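A minimal runnable sketch of that suggestion (assuming PyTorch 2.x; the names and placeholder bodies below are illustrative, not the PR's code): keep the autograd.Function subclass undecorated and compile a thin functional wrapper around .apply instead.

import torch

class _Attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v):
        ctx.save_for_backward(q, k, v)
        return q.clone()  # placeholder for the real flashattn_fwd output

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        # Placeholder gradients; the real backward would launch preprocess/bwd/postprocess.
        return grad_out.clone(), torch.zeros_like(k), torch.zeros_like(v)

def attention(q, k, v):
    return _Attention.apply(q, k, v)

attention_compiled = torch.compile(attention)  # compile the wrapper, not the Function subclass

q = torch.randn(4, 8, requires_grad=True)
k, v = torch.randn(4, 8), torch.randn(4, 8)
attention_compiled(q, k, v).sum().backward()
print(q.grad.shape)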

Comment on lines +591 to +604
        q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        do_unpad, _, _, _ = unpad_input(
            do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
        total_q, H, D_HEAD_QK = q.shape
        total_kv, HEAD_KV, D_HEAD_V = v.shape
        groups = H // HEAD_KV
        BATCH = len(cu_seqlens_q) - 1

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)]

⚠️ Potential issue | 🟡 Minor

Drop unused seqlens_k from saved_tensors (silence RUF059)

It’s saved but never used in backward. Rename to underscore at unpack or stop saving it.

Apply this diff:

-        q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        q, k, v, o, lse, seqlens_q, _seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors

Optionally, also remove seqlens_k from ctx.save_for_backward(...).

🧰 Tools
🪛 Ruff (0.14.0)

591-591: Unpacked variable seqlens_k is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py around lines
591 to 604, seqlens_k is saved in ctx.saved_tensors but never used in the
backward code which raises a lint warning; update the unpacking to ignore it
(rename seqlens_k to _ or use an extra underscore) or remove seqlens_k from the
values passed to ctx.save_for_backward at the forward pass so it is not saved at
all, and ensure subsequent code and indices align after that change.
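If the optional removal is taken, the forward-side save and the backward-side unpack must change together so the positions stay aligned. A toy autograd.Function (not the PR's code) showing the mirroring:

import torch

class Toy(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, seqlens_q):
        # Save only what backward reads; an unused tensor (like seqlens_k here) is simply not stored.
        ctx.save_for_backward(q, seqlens_q)
        return q * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        q, seqlens_q = ctx.saved_tensors  # unpack order must mirror save_for_backward exactly
        return grad_out * 2.0, None

x = torch.randn(4, requires_grad=True)
Toy.apply(x, torch.tensor([4])).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])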

@LeiWang1999
Member

LeiWang1999 commented Oct 20, 2025

Is this PR ready for review? @tzj-fxz

@tzj-fxz
Contributor Author

tzj-fxz commented Oct 21, 2025

Is this PR ready for review? @tzj-fxz

Yes, though further optimization depends on the atomic-add bug fix PR #1081.

