Merged
33 commits
7ca3299
[feat] add an example mma atom
botbw Jul 18, 2025
e309fc6
[fix] fix typo naming
botbw Jul 21, 2025
6ab6423
[feat] add a template to enable compilation
botbw Jul 23, 2025
e9c6bcb
[feat] add print util
botbw Jul 24, 2025
680d5f1
[WIP] pass on single block tile
botbw Jul 26, 2025
5e1e047
[feat] add sm80 metadata layout
botbw Aug 4, 2025
0827b0c
[chore] clean codebase
botbw Aug 4, 2025
faa5467
[CI] format.sh
botbw Aug 4, 2025
de1d2d6
[feat] add sm80 compress utils
botbw Aug 4, 2025
c566c50
[bugfix] fix C fragment layout
botbw Aug 7, 2025
60ca0e2
[refactor] use nvcc version instead of str
botbw Aug 7, 2025
d989dbe
[test] add test cases
botbw Aug 7, 2025
f7f920f
[chore] add a param check
botbw Aug 7, 2025
02396c1
[chore] format a bit
botbw Aug 7, 2025
a571f63
[chore] rename func to satisfy PEP 8 and appease gemini
botbw Aug 7, 2025
1df63f1
[chore] add check
botbw Aug 11, 2025
3eee2e0
[feat] support sm75 layout && add assertion && chore
botbw Aug 13, 2025
df621cc
[bug] fix illegal memory access when using two warps over N=32
botbw Sep 11, 2025
e00a3e3
[chore] add example
botbw Sep 11, 2025
636352e
[chore] format
botbw Sep 11, 2025
1531762
[example] update benchmark
botbw Sep 12, 2025
c52aa3c
[bugfix] fix namespace and format
botbw Sep 12, 2025
0e3bd60
Merge remote-tracking branch 'upstream/main' into gemm_sp_sm80
botbw Sep 12, 2025
fecadc8
[bugfix] fix incorrect param passing
botbw Sep 12, 2025
51655d7
[refactor] update variable declaration for clarity in gemm_layouts an…
LeiWang1999 Sep 15, 2025
333e4d4
Merge branch 'main' of https://github.com/tile-ai/tilelang into gemm_…
LeiWang1999 Sep 15, 2025
122bff0
[Cleanup] Remove unnecessary blank lines in metadata layout functions…
LeiWang1999 Sep 15, 2025
b8e195f
[CI] fix arch
botbw Sep 15, 2025
de09434
[example] add torch sparse benchmark
botbw Sep 15, 2025
9962ec5
[misc] polish && add reference && apply review suggestionsi && format
botbw Sep 15, 2025
88435da
[CI] format with clang-tidy
botbw Sep 15, 2025
597b66f
[Cleanup] Format and align template struct definitions in half.hpp, c…
LeiWang1999 Sep 15, 2025
27b4491
[Update] Modify CUDA version requirements in test_gemm_sp_sm80 and ma…
LeiWang1999 Sep 15, 2025
81 changes: 56 additions & 25 deletions benchmark/matmul/benchmark_matmul_sp.py
@@ -4,14 +4,21 @@
import torch
from triton.testing import do_bench

import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_metadata_layout

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

arch = nvcc.get_target_compute_version()

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}


def ref_program(A, B):
"""
@@ -79,11 +86,11 @@ def get_configs(M, N, K):
return configs


def matmul_sp(M, N, K):
def matmul_sp(M, N, K, accum_dtype):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
- B: (N, K)
- B: (K, N)
- C: (M, N)

Parameters
@@ -155,14 +162,14 @@ def kernel(
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
e_factor, e_dtype = ARCH_INFO[arch]

@T.prim_func
def main(
A_sparse: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, K // 8), 'uint8'),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), accum_dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
@@ -182,13 +189,13 @@ def main(
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
# Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor)
E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)

# Clear out the accumulation buffer
T.clear(C_local)
@@ -198,32 +205,27 @@ def main(
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass",
block_k=block_K),
E, mma_dtype="float16", backend="cutlass", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="sm90",
backend="cutlass",
block_k=block_K),
E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K),
})
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A_sparse[by * block_M, k * block_K], A_shared)
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
# Load a sub-block of E from global memory into E_shared
T.copy(E[by * block_M, k * block_K // 8], E_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
# C_local += A_shared @ B_shared
T.gemm_sp(
A_shared,
E_shared,
B_shared,
C_local,
transpose_B=True,
transpose_B=False,
policy=policy,
)
# Write back the results from C_local to the global memory C
@@ -241,24 +243,53 @@ def main(
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--disable_cache", action="store_true")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument(
"--bench_torch_sparse",
type=str,
choices=['cutlass', 'cusparselt'],
default=None,
help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
)
args = parser.parse_args()

if args.disable_cache:
tilelang.disable_cache()

M, N, K = args.m, args.n, args.k

# Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K

# matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K)
best_result = matmul_sp(M, N, K, args.accum_dtype)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")
ref_latency = do_bench(lambda: A @ B.T)
B = torch.randn(K, N, dtype=torch.float16, device="cuda")
ref_latency = do_bench(lambda: A @ B)

if args.bench_torch_sparse is not None:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
if args.bench_torch_sparse == 'cutlass':
SparseSemiStructuredTensor._FORCE_CUTLASS = True
A_sp = to_sparse_semi_structured(A, transposed=False)
torch_sparse_latency = do_bench(lambda: A_sp @ B)

# Print out the benchmark results
print(f"Best latency (s): {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}")

print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
if args.bench_torch_sparse is not None:
print(
f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
)

print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")
160 changes: 160 additions & 0 deletions examples/gemm_sp/example_gemm_sp.py
@@ -0,0 +1,160 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse

import tilelang
import tilelang.language as T

from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress
from tilelang.contrib import nvcc
from triton.testing import do_bench

import torch

arch = nvcc.get_target_compute_version()

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}

default_config = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 64,
'num_stages': 1,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
},
"h20": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
}


def generate_sparse_tensor(M: int, K: int, dtype=torch.float16, device='cuda'):
elem, group = 2, 4
full_tensor = torch.randn((M, K), dtype=dtype, device=device).view(M, -1, group)
indice = full_tensor.topk(elem, dim=-1).indices
full_tensor.scatter_(-1, indice, 0)
return full_tensor.view(M, K)


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
enable_rasterization):
e_factor, e_dtype = ARCH_INFO[arch]

@T.prim_func
def gemm_sp_fp16(
A_sparse: T.Tensor((M, K // 2), 'float16'),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16')
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

T.clear(C_local)
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
backend="cutlass",
block_k=block_K,
arch=arch),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)

T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])

return gemm_sp_fp16


def main():
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
**default_config[args.cfg][args.accum_dtype])

a = generate_sparse_tensor(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)

a_sparse, e = compress(
a,
transposed=False,
block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
arch=arch)
c = kernel(a_sparse, e, b)

ref_c = a @ b

assert not c.isnan().any(), "Kernel result contains NaNs, please report an issue"
torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}")

latency = do_bench(lambda: kernel(a_sparse, e, b))
ref_latency = do_bench(lambda: a @ b)

total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")


if __name__ == "__main__":
main()
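As a sanity check, the 2:4 structure produced by generate_sparse_tensor can be verified directly before compress is called. A minimal sketch, assuming generate_sparse_tensor from the example above is in scope; the helper name is ours:

# Verify 2:4 structured sparsity: at most 2 nonzeros in every group of 4 along K.
import torch

def check_2_to_4(t: torch.Tensor) -> bool:
    m, k = t.shape
    groups = t.view(m, k // 4, 4)
    nnz_per_group = (groups != 0).sum(dim=-1)
    return bool((nnz_per_group <= 2).all())

a = generate_sparse_tensor(256, 256, dtype=torch.float16, device="cuda")
assert check_2_to_4(a), "tensor is not 2:4 structured sparse"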
@@ -41,12 +41,12 @@ def main(
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K),
E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="sm90",
arch="9.0",
backend="cutlass",
block_k=block_K),
})
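The arch change in this hunk switches from the "sm90" spelling to the dotted compute-version string returned by nvcc.get_target_compute_version() ("8.0", "8.9", "9.0"), which the scripts above also use as keys into ARCH_INFO. A minimal sketch of how that string determines the metadata tensor shape (a CUDA GPU and a tilelang install are assumed):

# The dotted compute-version string selects the metadata compression factor
# and dtype; values are taken from ARCH_INFO in the scripts above.
from tilelang.contrib import nvcc

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}

arch = nvcc.get_target_compute_version()  # e.g. "8.0", "8.9", or "9.0"
e_factor, e_dtype = ARCH_INFO[arch]

M, K = 16384, 16384
# A 2:4-sparse A of logical shape (M, K) is stored as (M, K // 2) float16
# values plus an (M, K // e_factor) metadata tensor E of dtype e_dtype.
print(f"A_sparse: ({M}, {K // 2}) float16, E: ({M}, {K // e_factor}) {e_dtype}")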