Skip to content

Commit 557589f

Browse files
authored
[Example] Introduce split+sum template, and optimize atomic_add performance for bwd examples (#940)
* example fix * lint fix * bug fix * reduce test size.
1 parent 95170ab commit 557589f

File tree

11 files changed

+816
-259
lines changed

11 files changed

+816
-259
lines changed

examples/flash_attention/example_gqa_bwd.py

Lines changed: 188 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,118 @@ def flash_bwd_post(
147147
@tilelang.jit(pass_configs={
148148
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
149149
})
150-
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
150+
def flashattn_bwd_atomic_add(batch,
151+
heads,
152+
seq_len,
153+
dim_qk,
154+
dim_v,
155+
is_causal,
156+
block_M,
157+
block_N,
158+
threads=256,
159+
num_stages=2,
160+
groups=1):
161+
sm_scale = (1.0 / dim_qk)**0.5
162+
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
163+
head_kv = heads // groups
164+
q_shape = [batch, seq_len, heads, dim_qk]
165+
k_shape = [batch, seq_len, head_kv, dim_qk]
166+
v_shape = [batch, seq_len, head_kv, dim_v]
167+
dtype = "float16"
168+
accum_dtype = "float"
169+
170+
@T.prim_func
171+
def flash_bwd(
172+
Q: T.Tensor(q_shape, dtype), # type: ignore
173+
K: T.Tensor(k_shape, dtype), # type: ignore
174+
V: T.Tensor(v_shape, dtype), # type: ignore
175+
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
176+
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
177+
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
178+
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
179+
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
180+
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
181+
):
182+
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
183+
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
184+
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
185+
q = T.alloc_shared([block_N, dim_qk], dtype)
186+
V_shared = T.alloc_shared([block_M, dim_v], dtype)
187+
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
188+
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
189+
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
190+
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
191+
lse_shared = T.alloc_shared([block_N], accum_dtype)
192+
delta = T.alloc_shared([block_N], accum_dtype)
193+
do = T.alloc_shared([block_N, dim_v], dtype)
194+
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
195+
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
196+
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
197+
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
198+
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
199+
200+
T.annotate_layout({
201+
dQ: make_dq_layout(dQ),
202+
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
203+
})
204+
205+
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
206+
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
207+
T.clear(dv)
208+
T.clear(dk)
209+
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
210+
loop_ed = T.ceildiv(seq_len, block_N)
211+
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
212+
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
213+
T.clear(qkT)
214+
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
215+
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
216+
for i, j in T.Parallel(block_M, block_N):
217+
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
218+
if is_causal:
219+
for i, j in T.Parallel(block_M, block_N):
220+
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
221+
0)
222+
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
223+
T.clear(dsT)
224+
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
225+
T.copy(qkT, qkT_cast)
226+
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
227+
228+
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
229+
230+
for i, j in T.Parallel(block_M, block_N):
231+
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
232+
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
233+
234+
T.copy(dsT_cast, dsT_shared)
235+
T.clear(dq)
236+
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
237+
for i, j in T.Parallel(block_N, dim_qk):
238+
if k * block_N + i < seq_len:
239+
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
240+
T.copy(dv, dv_shared)
241+
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
242+
T.copy(dk, dk_shared)
243+
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
244+
245+
return flash_bwd
246+
247+
248+
@tilelang.jit(pass_configs={
249+
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
250+
})
251+
def flashattn_bwd_split(batch,
252+
heads,
253+
seq_len,
254+
dim_qk,
255+
dim_v,
256+
is_causal,
257+
block_M,
258+
block_N,
259+
threads=256,
260+
num_stages=2,
261+
groups=1):
151262
sm_scale = (1.0 / dim_qk)**0.5
152263
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
153264
head_kv = heads // groups
@@ -171,7 +282,7 @@ def flash_bwd(
171282
dK: T.Tensor(dk_shape, dtype), # type: ignore
172283
dV: T.Tensor(dv_shape, dtype), # type: ignore
173284
):
174-
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
285+
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
175286
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
176287
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
177288
q = T.alloc_shared([block_N, dim_qk], dtype)
@@ -202,20 +313,20 @@ def flash_bwd(
202313
T.clear(dk)
203314
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
204315
loop_ed = T.ceildiv(seq_len, block_N)
205-
for k in T.Pipelined(loop_st, loop_ed, num_stages=1):
316+
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
206317
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
207318
T.clear(qkT)
208319
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
320+
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
321+
T.clear(dsT)
322+
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
209323
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
210324
for i, j in T.Parallel(block_M, block_N):
211325
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
212326
if is_causal:
213327
for i, j in T.Parallel(block_M, block_N):
214328
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
215329
0)
216-
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
217-
T.clear(dsT)
218-
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
219330
T.copy(qkT, qkT_cast)
220331
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
221332

@@ -244,7 +355,7 @@ def flash_bwd(
244355
class _attention(torch.autograd.Function):
245356

246357
@staticmethod
247-
def forward(ctx, q, k, v, causal, groups=1):
358+
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
248359
BATCH, N_CTX, H, D_HEAD_QK = q.shape
249360
D_HEAD_V = v.shape[-1]
250361
block_M = 128
@@ -253,6 +364,7 @@ def forward(ctx, q, k, v, causal, groups=1):
253364
o, lse = mod(q, k, v)
254365
ctx.save_for_backward(q, k, v, o, lse)
255366
ctx.causal = causal
367+
ctx.use_atomic = use_atomic
256368
return o
257369

258370
@staticmethod
@@ -268,23 +380,59 @@ def maybe_contiguous(x):
268380
return x
269381

270382
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
271-
block_M = 64
383+
block_M = 128
272384
block_N = 32
273385
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
274386
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
275387
delta = mod_prep(o, do)
276-
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
277-
groups)
278-
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
279-
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
280-
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
281-
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
282-
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
283-
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
284-
kernel(q, k, v, do, lse, delta, dq, dk, dv)
285-
dq = mod_post(dq)
286-
dk, dv = dk.sum(0), dv.sum(0)
287-
return dq, dk, dv, None, None
388+
389+
if ctx.use_atomic:
390+
kernel = flashattn_bwd_atomic_add(
391+
BATCH,
392+
H,
393+
N_CTX,
394+
D_HEAD_QK,
395+
D_HEAD_V,
396+
ctx.causal,
397+
block_M,
398+
block_N,
399+
threads=256,
400+
num_stages=2,
401+
groups=groups)
402+
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
403+
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
404+
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
405+
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
406+
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
407+
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
408+
kernel(q, k, v, do, lse, delta, dq, dk, dv)
409+
dq = mod_post(dq)
410+
dk = dk.to(torch.float16)
411+
dv = dv.to(torch.float16)
412+
else:
413+
kernel = flashattn_bwd_split(
414+
BATCH,
415+
H,
416+
N_CTX,
417+
D_HEAD_QK,
418+
D_HEAD_V,
419+
ctx.causal,
420+
block_M,
421+
block_N,
422+
threads=256,
423+
num_stages=2,
424+
groups=groups)
425+
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
426+
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
427+
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
428+
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
429+
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
430+
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
431+
kernel(q, k, v, do, lse, delta, dq, dk, dv)
432+
dq = mod_post(dq)
433+
dk, dv = dk.sum(0), dv.sum(0)
434+
435+
return dq, dk, dv, None, None, None
288436

289437

290438
attention = _attention.apply
@@ -321,7 +469,8 @@ def main(BATCH: int = 1,
321469
D_HEAD_QK: int = 192,
322470
D_HEAD_V: int = 128,
323471
groups: int = 16,
324-
causal: bool = False):
472+
causal: bool = False,
473+
use_atomic: bool = True):
325474
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
326475
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
327476
total_flops = 3 * flops_per_qk + 2 * flops_per_v
@@ -341,7 +490,7 @@ def main(BATCH: int = 1,
341490
dO = (
342491
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
343492
device="cuda").normal_().requires_grad_())
344-
O = attention(Q, K, V, causal, groups)
493+
O = attention(Q, K, V, causal, groups, use_atomic)
345494
O.backward(dO, retain_graph=True)
346495
dQ, Q.grad = Q.grad.clone(), None
347496
dK, K.grad = K.grad.clone(), None
@@ -382,7 +531,22 @@ def run1():
382531
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
383532
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
384533
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
385-
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
534+
parser.add_argument('--causal', action='store_true', help='Causal flag')
386535
parser.add_argument('--groups', type=int, default=16, help='groups')
536+
parser.add_argument(
537+
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
538+
parser.add_argument(
539+
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
387540
args = parser.parse_args()
388-
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
541+
542+
# Handle backward compatibility and logic
543+
if args.use_split:
544+
use_atomic = False
545+
elif args.use_atomic:
546+
use_atomic = True
547+
else:
548+
# Default: use atomic
549+
use_atomic = True
550+
551+
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
552+
use_atomic)

0 commit comments

Comments
 (0)