Commit 14deea8

[Lint]
1 parent a69df37

File tree

1 file changed: +6 -4 lines changed


examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py

Lines changed: 6 additions & 4 deletions
@@ -11,7 +11,6 @@
 torch.manual_seed(1)
 
 
-
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
     assert mode in ["full", "random", "third"]
     if mode == "full":
@@ -369,16 +368,19 @@ def flash_bwd(
                 T.clear(dq)
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                 T.atomic_add(
-                    dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :],
+                    dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N,
+                       bx, :],
                     dq,
                     memory_order="release")
 
             T.atomic_add(
-                dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :],
+                dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
+                   bx // groups, :],
                 dv,
                 memory_order="release")
             T.atomic_add(
-                dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :],
+                dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
+                   bx // groups, :],
                 dk,
                 memory_order="release")
 
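For context, this commit is lint-only: it drops one blank line and re-wraps three long subscript expressions. In Python, removing the spaces around a slice colon and breaking a subscript across lines inside the brackets do not change what is indexed. A minimal sketch (plain PyTorch, with hypothetical sizes standing in for the kernel's block and offset parameters) checking that the old and new spellings of the dQ index select the same block:

    import torch

    # Hypothetical sizes standing in for the kernel's block/offset parameters.
    q_start_idx, k_base, block_N, bx = 8, 2, 4, 0
    dQ = torch.arange(64 * 2 * 3, dtype=torch.float32).reshape(64, 2, 3)

    # Old one-line spelling (spaces around the slice colon).
    a = dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :]

    # Lint-formatted spelling: no slice spaces, subscript wrapped across lines.
    b = dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N,
           bx, :]

    assert torch.equal(a, b)  # identical view of dQ; the change is purely formatting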
