
Commit 17bd0a6

[Enhancement] Deprecate split&sum in attn bwd examples on Hopper and migrate to vectorized atomic add (#1065)
1 parent ae9a6f0 commit 17bd0a6
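
Context for the diff below: in this backward scheme each thread block owns one K/V tile and loops over all query tiles, so its partial dQ must be reduced with the contributions of every other K/V block, and (because `groups` query heads share one KV head) its partial dK/dV must be reduced across the heads of a group. The example previously offered two reductions: per-element `T.atomic_add` from a register fragment, which forced a fragment-friendly dQ layout plus a postprocess kernel, or a split kernel that wrote per-group dK/dV slabs and summed them on the host. This commit keeps one kernel and reduces everything with tile-wide `T.atomic_add` on shared-memory staging tiles, which the backend can vectorize. A minimal before/after sketch of the dQ update, built only from lines that appear in the diff below (not a complete kernel):

    # Before: one scalar atomic add per element, issued from the register fragment `dq`.
    for i, j in T.Parallel(block_N, dim_qk):
        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

    # After: stage the fragment into swizzled shared memory, then issue a single
    # tile-wide atomic add over a contiguous slice of global memory.
    T.copy(dq, dq_shared)
    T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared)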

File tree: 2 files changed (+50, −266 lines)

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py

Lines changed: 43 additions & 227 deletions
@@ -113,51 +113,20 @@ def flash_bwd_prep(
     return flash_bwd_prep
 
 
-def make_dq_layout(dQ):
-    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
-    return T.Layout(dQ.shape,
-                    lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
-
-
-@tilelang.jit(
-    out_idx=[1], pass_configs={
-        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
-def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
-    dtype = "float16"
-    accum_dtype = "float"
-    shape = [batch, seq_len, heads, dim_qk]
-    blk = 64
-
-    @T.prim_func
-    def flash_bwd_post(
-            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
-            dQ_out: T.Tensor(shape, dtype),  # type: ignore
-    ):
-        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
-            T.annotate_layout({dQ: make_dq_layout(dQ)})
-            T.copy(
-                dQ[bz, bx * blk:(bx + 1) * blk, by, :],
-                dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
-            )
-
-    return flash_bwd_post
-
-
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_atomic_add(batch,
-                             heads,
-                             seq_len,
-                             dim_qk,
-                             dim_v,
-                             is_causal,
-                             block_M,
-                             block_N,
-                             threads=256,
-                             num_stages=2,
-                             groups=1):
+def flashattn_bwd(batch,
+                  heads,
+                  seq_len,
+                  dim_qk,
+                  dim_v,
+                  is_causal,
+                  block_M,
+                  block_N,
+                  threads=256,
+                  num_stages=2,
+                  groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
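
Note on the two deletions above: `make_dq_layout` existed only because the scalar `atomicAdd` path could not be vectorized, so the fp32 dQ accumulator was stored in an 8x8 GEMM-fragment order, and `flashattn_bwd_postprocess` existed only to copy that reordered buffer back into a row-major fp16 tensor. With the tile-wide atomic adds introduced below, dQ keeps its natural `[batch, seq_len, heads, dim_qk]` layout, so the host-side handling collapses to a dtype cast. A comment-level sketch of the caller change (the concrete code appears later in this diff):

    # Removed flow: accumulate into the fragment-reordered fp32 buffer, then run
    # the dedicated postprocess kernel to restore a row-major fp16 tensor:
    #     dq = mod_post(dq)
    # Retained flow: the accumulator is already row-major, so cast in place.
    dq = dq.to(torch.float16)
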
@@ -196,10 +165,13 @@ def flash_bwd(
             dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
             dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
             dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
+            dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
 
             T.annotate_layout({
-                dQ: make_dq_layout(dQ),
                 K_shared: tilelang.layout.make_swizzled_layout(K_shared),
+                dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
+                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
+                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
             })
 
             T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
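
The new `dq_shared` staging buffer is what the tile-wide dQ atomic add (next hunk) writes from, and all three staging buffers now get swizzled shared-memory layouts, which keeps the fragment-to-shared copies free of bank conflicts. A condensed view of the allocation and annotation pattern added above, with the tilelang calls exactly as they appear in this file:

    dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)  # fp32 staging tile for dQ
    dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)  # fp32 staging tile for dK
    dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)   # fp32 staging tile for dV
    T.annotate_layout({
        K_shared: tilelang.layout.make_swizzled_layout(K_shared),
        dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
        dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
        dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
    })
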
@@ -244,129 +216,12 @@ def flash_bwd(
                 T.clear(dq)
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
                 T.wait_wgmma(0)
-                for i, j in T.Parallel(block_N, dim_qk):
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                T.copy(dq, dq_shared)
+                T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared)
             T.copy(dv, dv_shared)
             T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
             T.copy(dk, dk_shared)
-            for i, j in T.Parallel(block_M, dim_qk):
-                T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
-
-    return flash_bwd
-
-
-@tilelang.jit(pass_configs={
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn_bwd_split(batch,
-                        heads,
-                        seq_len,
-                        dim_qk,
-                        dim_v,
-                        is_causal,
-                        block_M,
-                        block_N,
-                        threads=256,
-                        num_stages=2,
-                        groups=1):
-    sm_scale = (1.0 / dim_qk)**0.5
-    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
-    head_kv = heads // groups
-    q_shape = [batch, seq_len, heads, dim_qk]
-    k_shape = [batch, seq_len, head_kv, dim_qk]
-    v_shape = [batch, seq_len, head_kv, dim_v]
-    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
-    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
-    dtype = "float16"
-    accum_dtype = "float"
-
-    @T.prim_func
-    def flash_bwd(
-            Q: T.Tensor(q_shape, dtype),  # type: ignore
-            K: T.Tensor(k_shape, dtype),  # type: ignore
-            V: T.Tensor(v_shape, dtype),  # type: ignore
-            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
-            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
-            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
-            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
-            dK: T.Tensor(dk_shape, dtype),  # type: ignore
-            dV: T.Tensor(dv_shape, dtype),  # type: ignore
-    ):
-        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
-            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
-            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
-            q = T.alloc_shared([block_N, dim_qk], dtype)
-            V_shared = T.alloc_shared([block_M, dim_v], dtype)
-            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
-            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
-            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
-            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
-            lse_shared = T.alloc_shared([block_N], accum_dtype)
-            delta = T.alloc_shared([block_N], accum_dtype)
-            do = T.alloc_shared([block_N, dim_v], dtype)
-            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
-            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
-            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
-            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
-            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
-
-            T.annotate_layout({
-                dQ: make_dq_layout(dQ),
-                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
-                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
-                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
-            })
-
-            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
-            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
-            T.clear(dv)
-            T.clear(dk)
-            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
-            loop_ed = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
-                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
-                T.clear(qkT)
-                T.gemm(
-                    K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
-                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
-                T.clear(dsT)
-                T.gemm(
-                    V_shared,
-                    do,
-                    dsT,
-                    transpose_B=True,
-                    policy=T.GemmWarpPolicy.FullRow,
-                    wg_wait=-1)
-                T.wait_wgmma(1)
-
-                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
-                for i, j in T.Parallel(block_M, block_N):
-                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
-                if is_causal:
-                    for i, j in T.Parallel(block_M, block_N):
-                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
-                                                   0)
-                T.wait_wgmma(0)
-                T.copy(qkT, qkT_cast)
-                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
-
-                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
-
-                for i, j in T.Parallel(block_M, block_N):
-                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
-                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
-
-                T.copy(dsT_cast, dsT_shared)
-                T.clear(dq)
-                T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
-                T.wait_wgmma(0)
-                for i, j in T.Parallel(block_N, dim_qk):
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
-
-            T.copy(dv, dv_shared)
-            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
-            T.copy(dk, dk_shared)
-            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
 
     return flash_bwd
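
The same staging-plus-atomic pattern is why the whole `flashattn_bwd_split` kernel above can go: in GQA, the `groups` query heads that share one KV head all land on the same `dK`/`dV` slice (note the `bx // groups` index), so their contributions must be reduced somewhere. The split kernel sidestepped atomics by writing per-group slabs of shape `[groups, batch, seq_len, head_kv, dim]` and summing them on the host; the retained kernel now lets the tile-wide atomic adds do that reduction directly:

    # Retained dK path (dV is analogous): every query head of the group adds its
    # partial [block_M, dim_qk] tile into the dK slice of the KV head it shares.
    T.copy(dk, dk_shared)
    T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)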

@@ -403,54 +258,30 @@ def maybe_contiguous(x):
         block_M = 128
         block_N = 32
         mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
-        mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
         delta = mod_prep(o, do)
 
-        if ctx.use_atomic:
-            kernel = flashattn_bwd_atomic_add(
-                BATCH,
-                H,
-                N_CTX,
-                D_HEAD_QK,
-                D_HEAD_V,
-                ctx.causal,
-                block_M,
-                block_N,
-                threads=256,
-                num_stages=2,
-                groups=groups)
-            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
-            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
-            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
-            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
-            dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
-            dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
-            kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq = mod_post(dq)
-            dk = dk.to(torch.float16)
-            dv = dv.to(torch.float16)
-        else:
-            kernel = flashattn_bwd_split(
-                BATCH,
-                H,
-                N_CTX,
-                D_HEAD_QK,
-                D_HEAD_V,
-                ctx.causal,
-                block_M,
-                block_N,
-                threads=256,
-                num_stages=2,
-                groups=groups)
-            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
-            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
-            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
-            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
-            dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
-            dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
-            kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq = mod_post(dq)
-            dk, dv = dk.sum(0), dv.sum(0)
+        kernel = flashattn_bwd(
+            BATCH,
+            H,
+            N_CTX,
+            D_HEAD_QK,
+            D_HEAD_V,
+            ctx.causal,
+            block_M,
+            block_N,
+            threads=256,
+            num_stages=2,
+            groups=groups)
+        shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
+        shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
+        shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
+        dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
+        dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
+        dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
+        kernel(q, k, v, do, lse, delta, dq, dk, dv)
+        dq = dq.to(torch.float16)
+        dk = dk.to(torch.float16)
+        dv = dv.to(torch.float16)
 
         return dq, dk, dv, None, None, None
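
Two host-side details of the simplified driver above are worth spelling out: the output buffers must be created with `torch.zeros` rather than `torch.empty`, because every thread block now adds into them atomically instead of writing its own slice exactly once, and they stay in float32 so precision is not lost over the many partial sums along the sequence (and, for dK/dV, across the heads of a group). A compact annotated view of that pattern (dk and dv follow dq exactly):

    dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)  # zeros, not empty: atomics accumulate
    kernel(q, k, v, do, lse, delta, dq, dk, dv)                      # all blocks add their partial tiles
    dq = dq.to(torch.float16)                                        # single cast once accumulation is done
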

@@ -489,8 +320,7 @@ def main(BATCH: int = 1,
          D_HEAD_QK: int = 192,
          D_HEAD_V: int = 128,
          groups: int = 16,
-         causal: bool = False,
-         use_atomic: bool = True):
+         causal: bool = False):
     flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
     flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
     total_flops = 3 * flops_per_qk + 2 * flops_per_v
@@ -510,7 +340,7 @@ def main(BATCH: int = 1,
     dO = (
         torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
                     device="cuda").normal_().requires_grad_())
-    O = attention(Q, K, V, causal, groups, use_atomic)
+    O = attention(Q, K, V, causal, groups)
     O.backward(dO, retain_graph=True)
     dQ, Q.grad = Q.grad.clone(), None
     dK, K.grad = K.grad.clone(), None
@@ -553,20 +383,6 @@ def run1():
     parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
     parser.add_argument('--causal', action='store_true', help='Causal flag')
     parser.add_argument('--groups', type=int, default=16, help='groups')
-    parser.add_argument(
-        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument(
-        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
     args = parser.parse_args()
 
-    # Handle backward compatibility and logic
-    if args.use_split:
-        use_atomic = False
-    elif args.use_atomic:
-        use_atomic = True
-    else:
-        # Default: use atomic
-        use_atomic = True
-
-    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
-         use_atomic)
+    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
