Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ jobs:
uv run --no-project -m --
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=2 \
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
../examples

# NVIDIA CUDA tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,27 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():


def test_example_triton_sparse_gqa_decode_varlen_indice():
example_triton_sparse_gqa_decode_varlen_indice.main()
example_triton_sparse_gqa_decode_varlen_indice.main(
batch=16,
heads=16,
heads_kv=8,
max_cache_seqlen=4096,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)


def test_example_triton_sparse_gqa_decode_varlen_mask():
example_triton_sparse_gqa_decode_varlen_mask.main()
example_triton_sparse_gqa_decode_varlen_mask.main(
batch=16,
heads=16,
heads_kv=8,
max_cache_seqlen=4096,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions examples/cast/example_group_per_split_token_cast_to_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
return x_fp8


def main():
M, N, BG, blk_m = 8192, 8192, 2, 8
def main(M=8192, N=8192, BG=2, blk_m=8):
if dtype == "float":
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
elif dtype == "float16":
Expand Down
3 changes: 1 addition & 2 deletions examples/cast/example_per_token_cast_to_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return x_fp8, (x_amax / 448.0).view(m, -1)


def main():
M, N, blk_m = 8192, 8192, 8
def main(M=8192, N=8192, blk_m=8):
kernel = per_token_cast_to_fp8(M, N, blk_m)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
Expand Down
4 changes: 2 additions & 2 deletions examples/cast/test_example_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@


def test_example_group_per_split_token_cast_to_fp8():
example_group_per_split_token_cast_to_fp8.main()
example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)


def test_example_per_token_cast_to_fp8():
example_per_token_cast_to_fp8.main()
example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_example_topk_selector():


def test_example_fp8_lighting_indexer():
test_fp8_lighting_indexer()
test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)


@tilelang.testing.requires_cuda
Expand Down
3 changes: 1 addition & 2 deletions examples/dynamic_shape/example_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def ref_program(A, B):
print(f"Latency: {latency} ms")


def main():
M, N, K = 16384, 16384, 16384
def main(M=16384, N=16384, K=16384):
block_M, block_N, block_K = 128, 128, 32
trans_A, trans_B = False, False
in_dtype, out_dtype = "float16", "float16"
Expand Down
2 changes: 1 addition & 1 deletion examples/dynamic_shape/test_example_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


def test_example_dynamic():
example_dynamic.main()
example_dynamic.main(M=1024, N=1024, K=1024)


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions examples/flash_attention/test_example_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ def test_example_mha_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_fwd_bshd_wgmma_pipelined():
example_gqa_fwd_bshd_wgmma_pipelined.main()
example_gqa_fwd_bshd_wgmma_pipelined.main(
batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)


@tilelang.testing.requires_cuda
def test_example_gqa_fwd_bshd():
example_gqa_fwd_bshd.main()
example_gqa_fwd_bshd.main(
batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)


@tilelang.testing.requires_cuda
Expand Down
4 changes: 2 additions & 2 deletions testing/python/issue/test_tilelang_issue_96.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def run_gemm_pipeline_test(N, block_M=128, block_N=128, block_K=32):


def test_pipeline_large_matrix():
"""Test pipeline stages with large matrix multiplication (8192x8192)"""
run_gemm_pipeline_test(8192)
"""Test pipeline stages with large matrix multiplication (4096x4096)"""
run_gemm_pipeline_test(4096)


def test_pipeline_small_matrix():
Expand Down