
Commit 481cae4

[Example] Revert the atomic/split&sum templates in MHA backward examples (#943)
* revert split+sum template for MHA backward
* lint
* Update example_mha_bwd.py
* Update example_mha_bwd_wgmma_pipelined.py

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
1 parent 3aecab8 commit 481cae4
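
After this revert, the example exposes a single backward template behind `attention`. A minimal usage sketch (the import path is hypothetical; `attention` is defined in examples/flash_attention/example_mha_bwd.py, as the diff below shows):

import torch
from example_mha_bwd import attention  # hypothetical import; run from examples/flash_attention/

# Inputs follow the [batch, seq_len, heads, dim] layout the kernels expect.
q = torch.randn(8, 1024, 32, 64, dtype=torch.float16, device='cuda', requires_grad=True)
k = torch.randn_like(q).requires_grad_()
v = torch.randn_like(q).requires_grad_()

o = attention(q, k, v, False)  # the old use_atomic argument is gone after this commit
o.backward(torch.randn_like(o))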

File tree

2 files changed: +40 -317 lines changed


examples/flash_attention/example_mha_bwd.py

Lines changed: 20 additions & 153 deletions
@@ -149,110 +149,7 @@ def flash_bwd_post(
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_atomic_add(batch,
-                             heads,
-                             seq_len,
-                             dim,
-                             is_causal,
-                             block_M,
-                             block_N,
-                             threads=128,
-                             num_stages=2):
-    sm_scale = (1.0 / dim)**0.5
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
-    shape = [batch, seq_len, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
-
-    @T.prim_func
-    def flash_bwd(
-            Q: T.Tensor(shape, dtype),  # type: ignore
-            K: T.Tensor(shape, dtype),  # type: ignore
-            V: T.Tensor(shape, dtype),  # type: ignore
-            dO: T.Tensor(shape, dtype),  # type: ignore
-            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
-            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
-            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
-            dK: T.Tensor(shape, accum_dtype),  # type: ignore
-            dV: T.Tensor(shape, accum_dtype),  # type: ignore
-    ):
-        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
-            K_shared = T.alloc_shared([block_M, dim], dtype)
-            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
-            q = T.alloc_shared([block_N, dim], dtype)
-            V_shared = T.alloc_shared([block_M, dim], dtype)
-            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
-            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
-            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
-            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
-            lse_shared = T.alloc_shared([block_N], accum_dtype)
-            delta = T.alloc_shared([block_N], accum_dtype)
-            do = T.alloc_shared([block_N, dim], dtype)
-            dv = T.alloc_fragment([block_M, dim], accum_dtype)
-            dk = T.alloc_fragment([block_M, dim], accum_dtype)
-            dq = T.alloc_fragment([block_N, dim], accum_dtype)
-            dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
-            dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
-
-            T.annotate_layout({
-                dQ: make_dq_layout(dQ),
-                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
-            })
-            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
-            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
-            T.clear(dv)
-            T.clear(dk)
-            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
-            loop_ed = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
-                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
-                T.clear(qkT)
-                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
-                for i, j in T.Parallel(block_M, block_N):
-                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
-                if is_causal:
-                    for i, j in T.Parallel(block_M, block_N):
-                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
-                                                   0)
-                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
-                T.clear(dsT)
-                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(qkT, qkT_cast)
-                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
-
-                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
-
-                for i, j in T.Parallel(block_M, block_N):
-                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
-                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
-
-                T.copy(dsT_cast, dsT_shared)
-                T.clear(dq)
-                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
-                for i, j in T.Parallel(block_N, dim):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
-            T.copy(dv, dv_shared)
-            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
-            T.copy(dk, dk_shared)
-            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)
-
-    return flash_bwd
-
-
-@tilelang.jit(pass_configs={
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn_bwd_split(batch,
-                        heads,
-                        seq_len,
-                        dim,
-                        is_causal,
-                        block_M,
-                        block_N,
-                        threads=128,
-                        num_stages=2):
+def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
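
The removed `flashattn_bwd_atomic_add` accumulated dK/dV into float32 buffers with `T.atomic_add`, while the retained template writes each tile out directly in fp16: every (batch, head, K-block) kernel instance owns a disjoint dK/dV tile, so no cross-instance accumulation is needed there. A rough PyTorch sketch of the two write-back strategies, with hypothetical helper names and shapes reduced to one batch and head:

import torch

def dk_atomic_style(dsT, q_mat, block_m=64):
    # Removed variant: accumulate fp32 partials (atomic_add in the kernel), cast at the end.
    dk = torch.zeros(dsT.shape[0], q_mat.shape[1], dtype=torch.float32)
    for start in range(0, dsT.shape[0], block_m):
        tile = slice(start, start + block_m)
        dk[tile] += dsT[tile].float() @ q_mat.float()
    return dk.to(torch.float16)

def dk_direct_style(dsT, q_mat, block_m=64):
    # Retained variant: each K-tile is written exactly once, so cast and store fp16 directly.
    dk = torch.empty(dsT.shape[0], q_mat.shape[1], dtype=torch.float16)
    for start in range(0, dsT.shape[0], block_m):
        tile = slice(start, start + block_m)
        dk[tile] = (dsT[tile].float() @ q_mat.float()).to(torch.float16)
    return dk

Both produce the same dK up to rounding; the direct store drops the fp32 dK/dV buffers and the final cast that the atomic template needed.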
@@ -271,9 +168,13 @@ def flash_bwd(
             dK: T.Tensor(shape, dtype),  # type: ignore
             dV: T.Tensor(shape, dtype),  # type: ignore
     ):
-        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
+        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
             K_shared = T.alloc_shared([block_M, dim], dtype)
             dsT_shared = T.alloc_shared([block_M, block_N], dtype)
+            # should not store K to local if dim is large
+            # K_local = T.alloc_fragment([block_M, dim], dtype)
+            # K_local_T = T.alloc_fragment([block_M, dim], dtype)
+            # V_local = T.alloc_fragment([block_M, dim], dtype)
             q = T.alloc_shared([block_N, dim], dtype)
             V_shared = T.alloc_shared([block_M, dim], dtype)
             qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -301,7 +202,7 @@ def flash_bwd(
             T.clear(dk)
             loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
             loop_ed = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                 T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                 T.clear(qkT)
                 T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -328,8 +229,7 @@ def flash_bwd(
                 T.clear(dq)
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                 for i, j in T.Parallel(block_N, dim):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
             T.copy(dv, dv_shared)
             T.copy(dk, dk_shared)
             T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
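
Note the bounds check on the dQ update is gone in the retained kernel; the loop now assumes seq_len is a multiple of block_N. The accumulation itself is unchanged: each K-tile instance computes its partial dq = dsT_tile^T @ K_tile and atomically adds it into the fp32 dQ buffer. An fp32 reference for that reduction (hypothetical helper, one batch and head):

import torch

def dq_reference(dsT, k_mat, block_m=64):
    # dsT: [seq_k, seq_q] scaled score gradient; k_mat: [seq_k, dim].
    # Each K-tile contributes dsT[tile].T @ k_mat[tile]; the kernel's
    # T.atomic_add makes these contributions race-free across instances.
    dq = torch.zeros(dsT.shape[1], k_mat.shape[1], dtype=torch.float32)
    for start in range(0, k_mat.shape[0], block_m):
        tile = slice(start, start + block_m)
        dq += dsT[tile].float().t() @ k_mat[tile].float()
    return dq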
@@ -341,14 +241,13 @@ def flash_bwd(
 class _attention(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, q, k, v, causal, use_atomic=True):
+    def forward(ctx, q, k, v, causal):
         BATCH, N_CTX, H, D_HEAD = q.shape
         block_M = 64
         block_N = 64 if D_HEAD <= 128 else 32
         o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
         ctx.save_for_backward(q, k, v, o, lse)
         ctx.causal = causal
-        ctx.use_atomic = use_atomic
         return o
 
     @staticmethod
@@ -367,29 +266,14 @@ def maybe_contiguous(x):
         kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
         kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
         delta = kernel_prep(o, do)
-
-        if ctx.use_atomic:
-            kernel = flashattn_bwd_atomic_add(
-                BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
-            shape = [BATCH, N_CTX, H, D_HEAD]
-            dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
-            dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
-            dv = torch.zeros(shape, dtype=torch.float32, device=q.device)
-            kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq = kernel_post(dq)
-            dk = dk.to(torch.float16)
-            dv = dv.to(torch.float16)
-        else:
-            kernel = flashattn_bwd_split(
-                BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
-            shape = [BATCH, N_CTX, H, D_HEAD]
-            dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
-            dk = torch.empty(shape, dtype=torch.float16, device=q.device)
-            dv = torch.empty(shape, dtype=torch.float16, device=q.device)
-            kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq = kernel_post(dq)
-
-        return dq, dk, dv, None, None
+        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
+        shape = [BATCH, N_CTX, H, D_HEAD]
+        dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
+        dk = torch.empty(shape, dtype=torch.float16, device=q.device)
+        dv = torch.empty(shape, dtype=torch.float16, device=q.device)
+        kernel(q, k, v, do, lse, delta, dq, dk, dv)
+        dq = kernel_post(dq)
+        return dq, dk, dv, None
 
 
 attention = _attention.apply
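
dQ stays float32 in this path because multiple kernel instances accumulate into it atomically; `kernel_post` (flashattn_bwd_postprocess, defined earlier in the file, not shown in this diff) then brings the result back to fp16 and presumably undoes `make_dq_layout`. Ignoring the layout detail, a plain-PyTorch stand-in for that last step would be:

def postprocess_reference(dq_fp32):
    # Stand-in for flashattn_bwd_postprocess; models only the fp32 -> fp16
    # conversion, not any layout rearrangement the real kernel performs.
    return dq_fp32.to(torch.float16)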
@@ -415,9 +299,7 @@ def main(
     N_CTX: int = 1024,
     D_HEAD: int = 64,
     causal: bool = False,
-    use_atomic: bool = True,
 ):
-    print(f"Test with use_atomic: {use_atomic}")
     flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
     total_flops = 5 * flops_per_matmul
     if causal:
@@ -428,7 +310,7 @@ def main(
     K = torch.empty_like(Q).normal_().requires_grad_()
     V = torch.empty_like(Q).normal_().requires_grad_()
     dO = torch.randn_like(Q)
-    O = attention(Q, K, V, causal, use_atomic)
+    O = attention(Q, K, V, causal)
     O.backward(dO, retain_graph=True)
     dQ, Q.grad = Q.grad.clone(), None
     dK, K.grad = K.grad.clone(), None
@@ -444,7 +326,6 @@ def main(
     assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
     assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
     assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
-    print('All checks passed.✅')
 
     def run():
         O_ref.backward(dO, retain_graph=True)
@@ -468,20 +349,6 @@ def run1():
     parser.add_argument('--h', type=int, default=32, help='Number of heads')
     parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
     parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
-    parser.add_argument('--causal', action='store_true', help='Causal flag')
-    parser.add_argument(
-        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument(
-        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
+    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
     args = parser.parse_args()
-
-    # Handle backward compatibility and logic
-    if args.use_split:
-        use_atomic = False
-    elif args.use_atomic:
-        use_atomic = True
-    else:
-        # Default: use atomic
-        use_atomic = True
-
-    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic)
+    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
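
One caveat carried by the reverted flag: `type=bool` makes argparse call bool() on the raw string, so any non-empty value, including "False", parses as True; only omitting --causal keeps the default. A minimal repro:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
print(parser.parse_args(['--causal', 'False']).causal)  # True: bool('False') is truthy
print(parser.parse_args([]).causal)                     # False (the default)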
