
Commit 853f9c3

[BugFix] Add memory order and testing script for split version GQA bwd kernel (#1100)
* [BugFix] Add memory order for split version kernel; Remove torch manual seed
* [Lint] Manual
1 parent 4c9da81 commit 853f9c3

2 files changed: +13 -6 lines changed


examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py

Lines changed: 7 additions & 6 deletions
@@ -7,8 +7,6 @@
 from einops import rearrange, repeat
 from bert_padding import pad_input, unpad_input
 
-torch.manual_seed(1)
-
 
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
     assert mode in ["full", "random", "third"]
@@ -525,7 +523,10 @@ def flash_bwd(
         T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
         for i, j in T.Parallel(block_N, dim_qk):
             if k_base * block_N + i < q_current_seqlen:
-                T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j])
+                T.atomic_add(
+                    dQ[q_start_idx + k_base * block_N + i, bx, j],
+                    dq[i, j],
+                    memory_order="release")
 
         T.copy(dv, dv_shared)
         for i, d in T.Parallel(block_M, dim_v):
@@ -739,9 +740,9 @@ def main(BATCH: int = 1,
     dV_ref, V.grad = V.grad.clone(), None
 
     torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     print('All checks passed.✅')
 
 def run():
@@ -784,8 +785,8 @@ def run1():
     elif args.use_atomic:
         use_atomic = True
     else:
-        # Default: use atomic
-        use_atomic = True
+        # Default: use split
+        use_atomic = False
 
     main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
          use_atomic)
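
Two behavioral changes in this file deserve a note. The backward kernel accumulates partial dQ contributions into global memory through T.atomic_add; the new memory_order="release" argument orders each block's earlier writes before its atomic update becomes visible to other blocks, which, per the commit title, is the fix the split-version path needed (the dQ comparison is also moved to the end of the assertion list). Independently, the example now defaults to the split version rather than the atomic one. Below is a minimal sketch of the updated flag handling; the argparse setup and the --use_split branch are assumptions for illustration, since only the elif/else branches appear in the diff.

# Sketch of the flag resolution in run1(); --use_split and the argparse
# definitions are hypothetical, only the elif/else logic is taken from the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_split", action="store_true",
                    help="use the split-version bwd kernel (new default)")
parser.add_argument("--use_atomic", action="store_true",
                    help="accumulate dQ with atomic adds instead")
args = parser.parse_args()

if args.use_split:
    use_atomic = False
elif args.use_atomic:
    use_atomic = True
else:
    # Default: use split
    use_atomic = False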

examples/flash_attention/test_example_flash_attention.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,12 @@
 import example_mha_fwd_varlen
 import example_mha_bwd_wgmma_pipelined
 import example_mha_fwd_bhsd
+import example_gqa_bwd_tma_reduce_varlen
+
+
+@tilelang.testing.requires_cuda
+def test_example_gqa_bwd_tma_reduce_varlen():
+    example_gqa_bwd_tma_reduce_varlen.main()
 
 
 @tilelang.testing.requires_cuda
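
The added test imports the varlen GQA backward example and runs its main() with default arguments under the requires_cuda guard, so the suite now exercises the split-version path that the example defaults to. A hedged way to run just this test locally, assuming the suite is collected by pytest and the file lives at the path shown above:

# Hypothetical local invocation; the path and the pytest runner are assumptions.
import pytest

pytest.main([
    "examples/flash_attention/test_example_flash_attention.py::test_example_gqa_bwd_tma_reduce_varlen",
    "-q",
])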
