Skip to content

Commit 592ddd9

Browse files
committed
Merge remote-tracking branch 'origin/main' into add-pre-commit-config
2 parents a6d59fc + a13cde2 commit 592ddd9

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

43 files changed

+2943
-173
lines changed

.clang-tidy

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Checks: >
4646
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
4747
-clang-analyzer-deadcode.DeadStores,
4848
-clang-analyzer-optin.cplusplus.VirtualCall,
49+
-clang-diagnostic-tautological-constant-compare,
4950
5051
WarningsAsErrors: '*'
5152

.github/workflows/amd_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,4 @@ jobs:
119119
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
120120
cd testing/python/amd
121121
unset PYTHONPATH
122-
python -m pytest -v test_tilelang_test_amd.py
122+
python -m pytest -v --cache-clear test_tilelang_test_amd.py

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,11 @@ jobs:
115115
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
116116
cd examples
117117
unset PYTHONPATH
118-
python -m pytest -n 4 **/test*.py -v -r fE --durations=0
118+
python -m pytest -n 4 **/test*.py -v -r fE --durations=0 --cache-clear
119119
120120
- name: Run tests
121121
run: |
122122
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
123123
cd testing/python
124124
unset PYTHONPATH
125-
python -m pytest -n 4 -v -r fE --durations=0 --timeout=3600
125+
python -m pytest -n 4 -v -r fE --durations=0 --cache-clear --timeout=3600

.github/workflows/metal_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,4 @@ jobs:
9292
run: |
9393
cd testing/python
9494
unset PYTHONPATH
95-
python -m pytest -k metal -v -r fE --durations=0 --timeout=3600
95+
python -m pytest -k metal -v -r fE --durations=0 --cache-clear --timeout=3600

examples/deepseek_v32/sparse_mla_bwd.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,13 +333,14 @@ def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_c
333333

334334
def test_sparse_mla_bwd(B=1,
335335
S=4096,
336-
SKV=32768,
336+
SKV=8192,
337337
H=64,
338338
HKV=1,
339339
DQKV=576,
340340
DV=512,
341341
topk=2048,
342-
dtype=torch.bfloat16):
342+
dtype=torch.bfloat16,
343+
check_correctness=True):
343344
# Prepare data
344345
q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
345346
kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
@@ -359,7 +360,7 @@ def test_sparse_mla_bwd(B=1,
359360
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse)
360361
ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None)
361362

362-
if SKV <= 4096:
363+
if check_correctness:
363364
assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq")
364365
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
365366
print("assert_tensors_similar passed")
@@ -385,4 +386,13 @@ def fn():
385386

386387
if __name__ == "__main__":
387388
test_sparse_mla_bwd(
388-
B=1, S=4096, SKV=4096, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16)
389+
B=1,
390+
S=4096,
391+
SKV=8192,
392+
H=64,
393+
HKV=1,
394+
DQKV=576,
395+
DV=512,
396+
topk=2048,
397+
dtype=torch.bfloat16,
398+
check_correctness=True)

examples/deepseek_v32/sparse_mla_fwd.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,14 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
234234

235235
def test_sparse_mla_fwd(B=1,
236236
S=4096,
237-
SKV=4096,
237+
SKV=8192,
238238
H=128,
239239
HKV=1,
240240
DQK=576,
241241
DV=512,
242242
topk=2048,
243-
dtype=torch.bfloat16):
243+
dtype=torch.bfloat16,
244+
check_correctness=True):
244245
torch.random.manual_seed(0)
245246
q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
246247
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
@@ -254,7 +255,7 @@ def test_sparse_mla_fwd(B=1,
254255

255256
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
256257

257-
if SKV <= 4096:
258+
if check_correctness:
258259
# otherwise may cause out of memory
259260
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices)
260261
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
@@ -277,4 +278,13 @@ def fn():
277278

278279
if __name__ == "__main__":
279280
test_sparse_mla_fwd(
280-
B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16)
281+
B=1,
282+
S=4096,
283+
SKV=4096,
284+
H=128,
285+
HKV=1,
286+
DQK=576,
287+
DV=512,
288+
topk=2048,
289+
dtype=torch.bfloat16,
290+
check_correctness=True)

examples/deepseek_v32/sparse_mla_fwd_pipelined.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,15 @@ def ref_sparse_mla_fwd_interface(q,
399399

400400
def test_sparse_mla_fwd_pipelined(B=1,
401401
S=4096,
402-
SKV=4096,
402+
SKV=8192,
403403
H=128,
404404
HKV=1,
405405
DQK=576,
406406
DV=512,
407407
topk=2048,
408408
dtype=torch.bfloat16,
409-
q_start_s_index=1024):
409+
q_start_s_index=1024,
410+
check_correctness=True):
410411
KV_stride = 1
411412

412413
torch.random.manual_seed(0)
@@ -456,8 +457,8 @@ def fn():
456457
parser.add_argument("--test_correctness", action="store_true")
457458
args = parser.parse_args()
458459
if args.test_correctness:
459-
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16
460+
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
460461
else:
461462
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
462-
test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
463-
test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
463+
test_sparse_mla_fwd_pipelined(
464+
B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness)

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,23 @@ def test_example_fp8_lighting_indexer():
2020
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
2121
def test_example_sparse_mla_fwd():
2222
# small shapes for testing
23-
test_sparse_mla_fwd(S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256)
23+
test_sparse_mla_fwd(
24+
S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
2425

2526

2627
@tilelang.testing.requires_cuda
2728
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
2829
def test_example_sparse_mla_fwd_pipelined():
2930
# small shapes for testing
30-
test_sparse_mla_fwd_pipelined(S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256)
31+
test_sparse_mla_fwd_pipelined(
32+
S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
3133

3234

3335
@tilelang.testing.requires_cuda
3436
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
3537
def test_example_sparse_mla_bwd():
36-
test_sparse_mla_bwd()
38+
test_sparse_mla_bwd(
39+
S=1024, SKV=2048, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
3740

3841

3942
if __name__ == "__main__":

examples/flash_attention/test_example_flash_attention.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,18 @@ def test_example_gqa_bwd_wgmma_pipelined():
2727

2828
@tilelang.testing.requires_cuda
2929
def test_example_mha_bwd():
30-
example_mha_bwd.main()
30+
example_mha_bwd.main(BATCH=1)
3131

3232

3333
@tilelang.testing.requires_cuda
3434
def test_example_mha_bwd_bhsd():
35-
example_mha_bwd_bhsd.main()
35+
example_mha_bwd_bhsd.main(BATCH=1)
3636

3737

3838
@tilelang.testing.requires_cuda
3939
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
4040
def test_example_mha_bwd_wgmma_pipelined():
41-
example_mha_bwd_wgmma_pipelined.main()
41+
example_mha_bwd_wgmma_pipelined.main(BATCH=1)
4242

4343

4444
@tilelang.testing.requires_cuda
@@ -66,12 +66,12 @@ def test_example_mha_fwd_bhsd():
6666
@tilelang.testing.requires_cuda
6767
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
6868
def test_example_mha_fwd_bshd_wgmma_pipelined():
69-
example_mha_fwd_bshd_wgmma_pipelined.main()
69+
example_mha_fwd_bshd_wgmma_pipelined.main(batch=1, heads=32, seq_len=256)
7070

7171

7272
@tilelang.testing.requires_cuda
7373
def test_example_mha_fwd_bshd():
74-
example_mha_fwd_bshd.main()
74+
example_mha_fwd_bshd.main(batch=1, seq_len=256)
7575

7676

7777
@tilelang.testing.requires_cuda

examples/norm/test_rms_norm.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,9 @@ def ref_program(x):
6363
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12)
6464

6565

66-
def test_rms_norm():
67-
M, N, blk_m = 8192, 8192, 1
66+
def test_rms_norm(M=1024, N=1024, blk_m=1):
6867
program = rms_norm(M, N, blk_m)
69-
kernel = tilelang.compile(
70-
program,
71-
out_idx=-1,
72-
target="cuda",
73-
execution_backend="cython",
74-
pass_configs={"tl.disable_tma_lower": True})
68+
kernel = tilelang.compile(program, out_idx=-1, pass_configs={"tl.disable_tma_lower": True})
7569
profiler = kernel.get_profiler()
7670
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
7771

0 commit comments

Comments
 (0)