167 changes: 103 additions & 64 deletions examples/linear_attention/example_linear_attn_bwd.py
@@ -1,19 +1,20 @@
import torch
import tilelang as tl
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench

import argparse
from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA
from fla.modules.l2norm import l2norm_fwd
from einops import rearrange
from typing import Optional, Tuple


@tl.jit(
out_idx=[4, 5, 6],
@tilelang.jit(
pass_configs={
"tl.disable_tma_lower": True,
"tl.disable_warp_specialized": True
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def chunk_linear_attn_bwd_kernel(
def tl_fused_chunk_bwd_kernel(
B,
S,
H,
@@ -30,19 +31,19 @@ def chunk_linear_attn_bwd(
chunk_size = 64
BK = BV = 64 # Setting this to 128 can be faster, but introduces some numerical differences from FLA
assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
NK = tl.cdiv(DK, BK)
NV = tl.cdiv(DV, BV)
NT = tl.cdiv(S, chunk_size)
NK = tilelang.cdiv(DK, BK)
NV = tilelang.cdiv(DV, BV)
NT = tilelang.cdiv(S, chunk_size)

@T.prim_func
def chunk_linear_attn_bwd(
def fused_chunk_linear_attn_bwd(
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore
K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore
dO: T.Tensor([B, S, H, DV], dtype), # type: ignore
dQ: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore
dK: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore
dV: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore
dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore
dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore
dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore
):
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H
@@ -51,8 +52,11 @@ def chunk_linear_attn_bwd(
ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
ds_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
dq = T.alloc_fragment([chunk_size, BK], accum_dtype)
dq_shared = T.alloc_shared([chunk_size, BK], accum_dtype)
dk = T.alloc_fragment([chunk_size, BK], accum_dtype)
dk_shared = T.alloc_shared([chunk_size, BK], accum_dtype)
dv = T.alloc_fragment([chunk_size, BV], accum_dtype)
dv_shared = T.alloc_shared([chunk_size, BV], accum_dtype)
q = T.alloc_shared([chunk_size, BK], dtype)
k = T.alloc_shared([chunk_size, BK], dtype)
v = T.alloc_shared([chunk_size, BV], dtype)
@@ -61,22 +65,19 @@ def chunk_linear_attn_bwd(
h_shared = T.alloc_shared([BV, BK], dtype)
dh = T.alloc_fragment([BK, BV], accum_dtype)
dh_shared = T.alloc_shared([BK, BV], dtype)
T.clear(h)
T.clear(dh)

T.annotate_layout({
ds_shared: tl.layout.make_swizzled_layout(ds_shared),
q: tl.layout.make_swizzled_layout(q),
k: tl.layout.make_swizzled_layout(k),
v: tl.layout.make_swizzled_layout(v),
do: tl.layout.make_swizzled_layout(do),
h_shared: tl.layout.make_swizzled_layout(h_shared),
dh_shared: tl.layout.make_swizzled_layout(dh_shared)
dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared)
})
T.use_swizzle(10)

T.clear(h)
T.clear(dh)

# Calculate dQ
for i in T.Pipelined(0, NT, num_stages=1):
for i in T.Pipelined(0, NT):
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
@@ -92,12 +93,13 @@ def chunk_linear_attn_bwd(
T.gemm(v, k, h, transpose_A=True)
for row, col in T.Parallel(chunk_size, BK):
dq[row, col] *= scale
T.copy(
dq, dQ[i_v, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
i_k * BK:(i_k + 1) * BK])
T.copy(dq, dq_shared)
T.atomic_add(
dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK],
dq_shared)

# Calculate dK, dV (reversely)
for i in T.Pipelined(1, NT + 1, num_stages=1):
for i in T.Pipelined(1, NT + 1):
start = NT - i
for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
@@ -131,53 +133,90 @@ def chunk_linear_attn_bwd(
# Update dh
T.gemm(q, do, dh, transpose_A=True)

T.copy(
dk, dK[i_v, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_k * BK:(i_k + 1) * BK])
T.copy(
dv, dV[i_k, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV])

return chunk_linear_attn_bwd


def postprocess(dQ, dK, dV):
dQ = dQ[0] if dQ.size(0) == 1 else dQ.sum(0)
dK = dK[0] if dK.size(0) == 1 else dK.sum(0)
dV = dV[0] if dV.size(0) == 1 else dV.sum(0)
return dQ, dK, dV


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=4096, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=256, help='Head dim')
args = parser.parse_args()
B, S, H, D = args.B, args.S, args.H, args.D

T.copy(dk, dk_shared)
T.atomic_add(
dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_k * BK:(i_k + 1) * BK], dk_shared)
T.copy(dv, dv_shared)
T.atomic_add(
dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV], dv_shared)
Comment on lines +136 to +143

🛠️ Refactor suggestion | 🟠 Major

Gate benchmarks and keep test path fast; mirror forward changes

Prevent heavy benchmarking during tests and allow enabling via CLI.

-                T.atomic_add(
-                    dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
-                       i_k * BK:(i_k + 1) * BK], dk_shared)
+                T.atomic_add(
+                    dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
+                       i_k * BK:(i_k + 1) * BK], dk_shared)
@@
-def main(B=1, S=1024, H=16, D=128):
+def main(B=1, S=1024, H=16, D=128, run_bench: bool = False):
@@
-    print('Passed all tests!✅')
-
-    # Benchmark
-    q.grad = k.grad = v.grad = None
-    o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
-    t1 = do_bench(
-        lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100, backend='cupti')
-    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), warmup=25, rep=100, backend='cupti')
-    print(f'Triton latency: {t1:.3f} ms')
-    print(f'TileLang latency: {t2:.3f} ms')
-    print(f'Speedup: {t1/t2:.3f}x')
+    print('Passed all tests!✅')
+    if run_bench:
+        # Benchmark
+        q.grad = k.grad = v.grad = None
+        o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
+        t1 = do_bench(
+            lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100, backend='cupti')
+        t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), warmup=25, rep=100, backend='cupti')
+        print(f'Triton latency: {t1:.3f} ms')
+        print(f'TileLang latency: {t2:.3f} ms')
+        print(f'Speedup: {t1/t2:.3f}x')
@@
-    main(args.B, args.S, args.H, args.D)
+    main(args.B, args.S, args.H, args.D, run_bench=True)

Also applies to: 180-209, 213-221

🤖 Prompt for AI Agents
In examples/linear_attention/example_linear_attn_bwd.py around lines 134-141
(and similarly lines 180-209, 213-221), the test path currently runs heavy
benchmarks; update the code to skip or short-circuit expensive benchmarking
during test runs and expose a CLI flag to enable full benchmarks. Modify the
logic to check a new command-line/single-run flag (e.g., --run-bench) or an
environment variable before executing benchmark code paths so tests use the
lightweight forward-mirror path by default; ensure the forward changes that
mirror behavior are applied consistently in the referenced blocks and that the
default test execution remains fast while full benchmarking is opt-in via the
CLI flag.
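
A minimal, self-contained sketch of the gating the reviewer proposes, assuming the example's existing `main(B, S, H, D)` entry point; the `--run-bench` flag name follows the suggestion above, and the body is illustrative only (the benchmark calls stay as comments since they depend on tensors set up earlier in `main`):

import argparse


def main(B=1, S=1024, H=16, D=128, run_bench: bool = False):
    # ... build inputs and run the correctness checks against the reference here ...
    print('Passed all tests!✅')

    if not run_bench:
        return  # default (test) path: skip the heavy benchmark entirely

    # Benchmark path, only reached when --run-bench is passed, e.g.:
    # t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti')
    # t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti')
    # print(f'Speedup: {t1 / t2:.3f}x')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--B', type=int, default=8, help='Batch size')
    parser.add_argument('--S', type=int, default=1024, help='Seq len')
    parser.add_argument('--H', type=int, default=32, help='Num heads')
    parser.add_argument('--D', type=int, default=128, help='Head dim')
    parser.add_argument('--run-bench', action='store_true', help='Run the latency benchmark')
    args = parser.parse_args()
    main(args.B, args.S, args.H, args.D, run_bench=args.run_bench)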


return fused_chunk_linear_attn_bwd


def tl_fused_chunk_bwd(Q, K, V, dO):
B, S, H, D = Q.shape
kernel = tl_fused_chunk_bwd_kernel(B, S, H, D, D)
dQ = torch.zeros_like(Q, dtype=torch.float32)
dK = torch.zeros_like(K, dtype=torch.float32)
dV = torch.zeros_like(V, dtype=torch.float32)
kernel(Q, K, V, dO, dQ, dK, dV)
return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16)


def ref_program(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = q.float(), k.float(), v.float()
if scale is None:
scale = q.shape[-1]**-0.5
chunk_size = 64
q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
kv = k.transpose(-1, -2) @ v
kv = kv.cumsum(2)
h = kv[:, :, -1, :, :]
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
inter = q @ kv
intra = ((q @ k.transpose(-1, -2)).masked_fill_(
torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
0)) @ v
o = inter + intra
return rearrange(o, 'b h n c d -> b (n c) h d'), h


def main(B=1, S=1024, H=16, D=128):
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)

kernel = chunk_linear_attn_bwd_kernel(B, S, H, D, D)
dq, dk, dv = postprocess(*kernel(q, k, v, do))
o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
# qk norm is necessary for linear attn
q = l2norm_fwd(q)[0].requires_grad_(True)
k = l2norm_fwd(k)[0].requires_grad_(True)

dq, dk, dv = tl_fused_chunk_bwd(q, k, v, do)
q.grad = k.grad = v.grad = None
o_ref, _ = ref_program(q, k, v)
o_ref.backward(do, retain_graph=True)
if torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad):
print('Passed all tests!✅')
else:
print('Failed some tests!❌')
t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100)

assert torch.allclose(
dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}'
assert torch.allclose(
dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}'
assert torch.allclose(
dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}'
print('Passed all tests!✅')

# Benchmark
q.grad = k.grad = v.grad = None
o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
t2 = do_bench(lambda: postprocess(*kernel(q, k, v, do)), warmup=25, rep=100)
t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti')
t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti')
print(f'Triton latency: {t1:.3f} ms')
print(f'TileLang latency: {t2:.3f} ms')
print(f'Speedup: {t1/t2:.3f}x')


if __name__ == '__main__':
main()
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=1024, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=128, help='Head dim')
args = parser.parse_args()

main(args.B, args.S, args.H, args.D)