@@ -7,8 +7,6 @@
 from einops import rearrange, repeat
 from bert_padding import pad_input, unpad_input
 
-torch.manual_seed(1)
-
 
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
     assert mode in ["full", "random", "third"]
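For context, `generate_random_padding_mask` draws one sequence length per batch element and marks the valid positions; with the fixed seed removed above, every run now exercises a different padding pattern. A minimal sketch of such a generator, with assumed length ranges (the real implementation lives in the test helpers):

```python
import torch
from einops import repeat


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        # Every sequence uses the full length: no padding at all.
        lengths = torch.full((batch_size, 1), max_seqlen, device=device)
    elif mode == "random":
        # Arbitrary lengths in [1, max_seqlen]; assumed range.
        lengths = torch.randint(1, max_seqlen + 1, (batch_size, 1), device=device)
    else:
        # "third": lengths drawn from the upper two thirds; assumed range.
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    # True where the position index falls inside the sequence, i.e. a real token.
    return repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
```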
@@ -525,7 +523,10 @@ def flash_bwd(
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                 for i, j in T.Parallel(block_N, dim_qk):
                     if k_base * block_N + i < q_current_seqlen:
-                        T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j])
+                        T.atomic_add(
+                            dQ[q_start_idx + k_base * block_N + i, bx, j],
+                            dq[i, j],
+                            memory_order="release")
 
             T.copy(dv, dv_shared)
             for i, d in T.Parallel(block_M, dim_v):
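In this hunk, every K/V tile accumulates a partial `dQ` contribution into the same global rows, so the adds are atomic, and release memory ordering ensures a block's preceding writes are visible before its update lands. The sum itself is order-independent, which is why an atomic path and a split-then-reduce path can coexist; a plain-PyTorch sketch of that equivalence (shapes illustrative, not the kernel's):

```python
import torch

# Each K/V tile of the backward pass produces a partial dQ for the same query
# rows. The kernel sums them with T.atomic_add into global memory; the
# alternative "split" path writes one partial buffer per tile and reduces at
# the end. Both give the same result up to floating-point non-associativity.
seqlen_q, dim_qk, num_kv_tiles = 128, 64, 4
partials = [torch.randn(seqlen_q, dim_qk) for _ in range(num_kv_tiles)]

dq_atomic = torch.zeros(seqlen_q, dim_qk)
for p in partials:  # emulates serialized atomic adds
    dq_atomic += p

dq_split = torch.stack(partials).sum(0)  # emulates the split + reduce path

torch.testing.assert_close(dq_atomic, dq_split)
```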
@@ -739,9 +740,9 @@ def main(BATCH: int = 1,
     dV_ref, V.grad = V.grad.clone(), None
 
     torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     print('All checks passed.✅')
 
     def run():
@@ -784,8 +785,8 @@ def run1():
     elif args.use_atomic:
         use_atomic = True
     else:
-        # Default: use atomic
-        use_atomic = True
+        # Default: use split
+        use_atomic = False
 
     main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
          use_atomic)
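After this change a bare invocation takes the split path. A hedged sketch of the surrounding flag handling (only the `elif`/`else` arms appear in the diff; the `argparse` wiring is an assumption):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--use_atomic", action="store_true")
group.add_argument("--use_split", action="store_true")
args = parser.parse_args([])  # no flags given, as in a bare invocation

if args.use_split:
    use_atomic = False
elif args.use_atomic:
    use_atomic = True
else:
    # Default: use split
    use_atomic = False

assert use_atomic is False
```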