
Commit 9a4a359

lint fix

Parent: 0bfc49a

8 files changed: +174 additions, -47 deletions

examples/flash_attention/example_gqa_bwd.py

Lines changed: 54 additions & 11 deletions
@@ -147,7 +147,17 @@ def flash_bwd_post(
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
+def flashattn_bwd_atomic_add(batch,
+                             heads,
+                             seq_len,
+                             dim_qk,
+                             dim_v,
+                             is_causal,
+                             block_M,
+                             block_N,
+                             threads=256,
+                             num_stages=2,
+                             groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -228,17 +238,27 @@ def flash_bwd(
                 if k * block_N + i < seq_len:
                     T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
             T.copy(dv, dv_shared)
-            T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx // groups, :], dv_shared)
+            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
             T.copy(dk, dk_shared)
-            T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx // groups, :], dk_shared)
+            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
 
     return flash_bwd
 
 
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
+def flashattn_bwd_split(batch,
+                        heads,
+                        seq_len,
+                        dim_qk,
+                        dim_v,
+                        is_causal,
+                        block_M,
+                        block_N,
+                        threads=256,
+                        num_stages=2,
+                        groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -367,8 +387,18 @@ def maybe_contiguous(x):
         delta = mod_prep(o, do)
 
         if ctx.use_atomic:
-            kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
-                                              threads=256, num_stages=2, groups=groups)
+            kernel = flashattn_bwd_atomic_add(
+                BATCH,
+                H,
+                N_CTX,
+                D_HEAD_QK,
+                D_HEAD_V,
+                ctx.causal,
+                block_M,
+                block_N,
+                threads=256,
+                num_stages=2,
+                groups=groups)
             shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
             shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
             shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
@@ -380,8 +410,18 @@ def maybe_contiguous(x):
             dk = dk.to(torch.float16)
             dv = dv.to(torch.float16)
         else:
-            kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
-                                         threads=256, num_stages=2, groups=groups)
+            kernel = flashattn_bwd_split(
+                BATCH,
+                H,
+                N_CTX,
+                D_HEAD_QK,
+                D_HEAD_V,
+                ctx.causal,
+                block_M,
+                block_N,
+                threads=256,
+                num_stages=2,
+                groups=groups)
             shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
             shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
             shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
@@ -493,8 +533,10 @@ def run1():
     parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
     parser.add_argument('--causal', action='store_true', help='Causal flag')
     parser.add_argument('--groups', type=int, default=16, help='groups')
-    parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV')
+    parser.add_argument(
+        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
+    parser.add_argument(
+        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
     args = parser.parse_args()
 
     # Handle backward compatibility and logic
@@ -506,4 +548,5 @@ def run1():
         # Default: use atomic
         use_atomic = True
 
-    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
+    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
+         use_atomic)
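
Beyond the reflowed signatures, these hunks show the two dK/dV accumulation strategies the example switches between: the atomic path adds every query group's partial result straight into a single [BATCH, N_CTX, HEAD_KV, D] buffer with T.atomic_add, while the split path writes per-group partials into a [groups, BATCH, N_CTX, HEAD_KV, D] buffer and reduces it after the kernel (the "sum after kernel" comments). A minimal PyTorch sketch of why the two agree, with made-up sizes and random tensors standing in for the kernel's partial gradients:

import torch

# Hypothetical sizes; in the example they come from the CLI arguments.
groups, batch, seq_len, head_kv, d_head = 4, 2, 128, 2, 64

# Per-group partial dK contributions, standing in for what each query group's
# blocks produce inside the backward kernel.
partial_dk = torch.randn(groups, batch, seq_len, head_kv, d_head)

# Atomic-add strategy: all groups accumulate into one shared buffer.
dk_atomic = torch.zeros(batch, seq_len, head_kv, d_head)
for g in range(groups):
    dk_atomic += partial_dk[g]      # models T.atomic_add into dK

# Split strategy: keep per-group buffers and reduce once after the kernel.
dk_split = partial_dk.sum(dim=0)

torch.testing.assert_close(dk_atomic, dk_split)

The trade-off implied by the shapes is memory versus contention: the split buffers are groups times larger, but they avoid atomic updates to the same locations.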

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py

Lines changed: 53 additions & 10 deletions
@@ -147,7 +147,17 @@ def flash_bwd_post(
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
+def flashattn_bwd_atomic_add(batch,
+                             heads,
+                             seq_len,
+                             dim_qk,
+                             dim_v,
+                             is_causal,
+                             block_M,
+                             block_N,
+                             threads=256,
+                             num_stages=2,
+                             groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -238,7 +248,7 @@ def flash_bwd(
                 if k * block_N + i < seq_len:
                     T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
             T.copy(dv, dv_shared)
-            T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx // groups, :], dv_shared)
+            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
             T.copy(dk, dk_shared)
             for i, j in T.Parallel(block_M, dim_qk):
                 T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
@@ -249,7 +259,17 @@ def flash_bwd(
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
+def flashattn_bwd_split(batch,
+                        heads,
+                        seq_len,
+                        dim_qk,
+                        dim_v,
+                        is_causal,
+                        block_M,
+                        block_N,
+                        threads=256,
+                        num_stages=2,
+                        groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -389,8 +409,18 @@ def maybe_contiguous(x):
         delta = mod_prep(o, do)
 
         if ctx.use_atomic:
-            kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
-                                              threads=256, num_stages=2, groups=groups)
+            kernel = flashattn_bwd_atomic_add(
+                BATCH,
+                H,
+                N_CTX,
+                D_HEAD_QK,
+                D_HEAD_V,
+                ctx.causal,
+                block_M,
+                block_N,
+                threads=256,
+                num_stages=2,
+                groups=groups)
             shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
             shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
             shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
@@ -402,8 +432,18 @@ def maybe_contiguous(x):
             dk = dk.to(torch.float16)
             dv = dv.to(torch.float16)
         else:
-            kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
-                                         threads=256, num_stages=2, groups=groups)
+            kernel = flashattn_bwd_split(
+                BATCH,
+                H,
+                N_CTX,
+                D_HEAD_QK,
+                D_HEAD_V,
+                ctx.causal,
+                block_M,
+                block_N,
+                threads=256,
+                num_stages=2,
+                groups=groups)
             shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
             shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
             shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
@@ -515,8 +555,10 @@ def run1():
     parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
     parser.add_argument('--causal', action='store_true', help='Causal flag')
     parser.add_argument('--groups', type=int, default=16, help='groups')
-    parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV')
+    parser.add_argument(
+        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
+    parser.add_argument(
+        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
     args = parser.parse_args()
 
     # Handle backward compatibility and logic
@@ -528,4 +570,5 @@ def run1():
         # Default: use atomic
         use_atomic = True
 
-    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
+    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
+         use_atomic)
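
This pipelined variant receives the same signature and call-site reflow as example_gqa_bwd.py. One detail the context lines make visible in both files is the pair of scale constants: sm_scale = (1.0 / dim_qk)**0.5 for the softmax and scale = sm_scale * 1.44269504, the extra factor being log2(e), presumably so an exp2-based fast-math path (cf. TL_ENABLE_FAST_MATH in the decorators) reproduces exp. A small self-contained check of that identity, with a made-up head dimension and score:

import math

dim_qk = 64                                   # hypothetical head dimension
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504      # log2(e), as in the kernels above

x = 0.73                                      # an arbitrary attention score
assert math.isclose(math.exp(x * sm_scale),   # what softmax needs
                    2.0 ** (x * scale),       # what an exp2-based kernel computes
                    rel_tol=1e-6)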

examples/flash_attention/example_mha_bwd.py

Lines changed: 28 additions & 10 deletions
@@ -149,7 +149,15 @@ def flash_bwd_post(
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_atomic_add(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=128, num_stages=2):
+def flashattn_bwd_atomic_add(batch,
+                             heads,
+                             seq_len,
+                             dim,
+                             is_causal,
+                             block_M,
+                             block_N,
+                             threads=128,
+                             num_stages=2):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -226,17 +234,25 @@ def flash_bwd(
                 if k * block_N + i < seq_len:
                     T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
             T.copy(dv, dv_shared)
-            T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx, :], dv_shared)
+            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
             T.copy(dk, dk_shared)
-            T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx, :], dk_shared)
+            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)
 
     return flash_bwd
 
 
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_split(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=128, num_stages=2):
+def flashattn_bwd_split(batch,
+                        heads,
+                        seq_len,
+                        dim,
+                        is_causal,
+                        block_M,
+                        block_N,
+                        threads=128,
+                        num_stages=2):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -353,8 +369,8 @@ def maybe_contiguous(x):
         delta = kernel_prep(o, do)
 
         if ctx.use_atomic:
-            kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N,
-                                              threads=128, num_stages=2)
+            kernel = flashattn_bwd_atomic_add(
+                BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
             shape = [BATCH, N_CTX, H, D_HEAD]
             dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
             dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
@@ -364,8 +380,8 @@ def maybe_contiguous(x):
             dk = dk.to(torch.float16)
             dv = dv.to(torch.float16)
         else:
-            kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N,
-                                         threads=128, num_stages=2)
+            kernel = flashattn_bwd_split(
+                BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
             shape = [BATCH, N_CTX, H, D_HEAD]
             dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
             dk = torch.empty(shape, dtype=torch.float16, device=q.device)
@@ -453,8 +469,10 @@ def run1():
     parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
     parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
     parser.add_argument('--causal', action='store_true', help='Causal flag')
-    parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV')
+    parser.add_argument(
+        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
+    parser.add_argument(
+        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
     args = parser.parse_args()
 
     # Handle backward compatibility and logic
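
As in the GQA examples, the atomic path here allocates dq/dk/dv as float32 zeros and only casts to float16 once the kernel has finished (dk = dk.to(torch.float16) in the context above). The snippet below is not from the file; it is a small illustration of why accumulating many small partial gradients directly in float16 would be riskier than accumulating in float32 and casting once:

import torch

# 10,000 partial contributions of 1e-4 each; the true sum is 1.0.
parts = torch.full((10000,), 1e-4)

acc_fp16 = torch.zeros((), dtype=torch.float16)
for p in parts:
    acc_fp16 += p.half()            # repeated low-precision accumulation stalls

acc_fp32 = parts.sum().half()       # accumulate in float32, cast once at the end

print(acc_fp16.item(), acc_fp32.item())   # roughly 0.25 vs. 1.0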

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py

Lines changed: 28 additions & 10 deletions
@@ -146,7 +146,15 @@ def flash_bwd_post(
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_atomic_add(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=256, num_stages=2):
+def flashattn_bwd_atomic_add(batch,
+                             heads,
+                             seq_len,
+                             dim,
+                             is_causal,
+                             block_M,
+                             block_N,
+                             threads=256,
+                             num_stages=2):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -234,17 +242,25 @@ def flash_bwd(
                 if k * block_N + i < seq_len:
                     T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
             T.copy(dv, dv_shared)
-            T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx, :], dv_shared)
+            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
             T.copy(dk, dk_shared)
-            T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx, :], dk_shared)
+            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)
 
     return flash_bwd
 
 
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_split(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=256, num_stages=2):
+def flashattn_bwd_split(batch,
+                        heads,
+                        seq_len,
+                        dim,
+                        is_causal,
+                        block_M,
+                        block_N,
+                        threads=256,
+                        num_stages=2):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -374,8 +390,8 @@ def maybe_contiguous(x):
         delta = mod_prep(o, do)
 
         if ctx.use_atomic:
-            mod = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N,
-                                           threads=256, num_stages=2)
+            mod = flashattn_bwd_atomic_add(
+                BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=256, num_stages=2)
             shape = [BATCH, N_CTX, H, D_HEAD]
             dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
             dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
@@ -385,8 +401,8 @@ def maybe_contiguous(x):
             dk = dk.to(torch.float16)
             dv = dv.to(torch.float16)
         else:
-            mod = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N,
-                                      threads=256, num_stages=2)
+            mod = flashattn_bwd_split(
+                BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=256, num_stages=2)
             shape = [BATCH, N_CTX, H, D_HEAD]
             dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
             dk = torch.empty(shape, dtype=torch.float16, device=q.device)
@@ -474,8 +490,10 @@ def run1():
     parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
     parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', action='store_true', help='Causal flag')
-    parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV')
+    parser.add_argument(
+        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
+    parser.add_argument(
+        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
     args = parser.parse_args()
 
     # Handle backward compatibility and logic
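
All four examples expose the same pair of flags and then collapse them into one use_atomic boolean; only fragments of that logic ("# Handle backward compatibility and logic", "# Default: use atomic", "use_atomic = True") are visible in these hunks. The sketch below is a hedged reconstruction of that selection, with the branch structure assumed rather than copied from the file:

import argparse

# Flag definitions copied from the diff; the resolution logic below is an
# assumption about how the examples pick a default.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
    '--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args([])        # [] simulates running with no flags

# Handle backward compatibility and logic (assumed structure)
if args.use_split:
    use_atomic = False
elif args.use_atomic:
    use_atomic = True
else:
    # Default: use atomic
    use_atomic = True

print(use_atomic)                   # True when neither flag is given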
