examples/gemv/example_gemv.py (208 changes: 131 additions & 77 deletions)
@@ -216,75 +216,122 @@ def main(
return main


def get_block_template_configs():
iter_params = dict(
block_M=[2, 4, 8, 32, 64, 128],
block_N=[2, 4, 8, 32, 64, 128],
num_stages=[0, 1, 2, 3, 4],
threads=[32, 64, 128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@tl.autotune(
configs=get_block_template_configs(),
warmup=3,
rep=20,
)
@tl.jit(
pass_configs={
tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
out_idx=[2],
)
def gemv_alloc_reducer(M,
N,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16",
accum_dtype: str = "float"):

@T.prim_func
def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M,
dtype)): # type: ignore
with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
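            # One running partial sum per output row of this block; the replicated copies are merged by T.finalize_reducer after the N-tile loop.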
T.clear(o_reducer)
for i0_n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages):
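                # Stage the A tile through shared memory into a register fragment; num_stages controls the software-pipelining depth of this loop.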
a_smem = T.alloc_shared((block_M, block_N), dtype)
T.copy(a[i0_m * block_M, i0_n * block_N], a_smem)
a_frag = T.alloc_fragment((block_M, block_N), dtype)
T.copy(a_smem, a_frag)
x_frag = T.alloc_fragment(block_N, dtype)
T.copy(x[i0_n * block_N], x_frag)
for i1_m, i1_n in T.Parallel(block_M, block_N):
o_reducer[i1_m] += a_frag[i1_m, i1_n] * x_frag[i1_n]
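            # Collapse the replicated partial sums into a single value per row before copying back to global memory.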
T.finalize_reducer(o_reducer)
T.copy(o_reducer, o[i0_m * block_M])

return main
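
As a usage illustration only (not part of this diff), here is a minimal sketch of calling the autotuned block-template kernel above. It assumes the usual tilelang convention that the tensor at out_idx is allocated and returned; the shapes, device, and tolerances are made up for the example.

# Sketch: invoke the autotuned GEMV kernel (shapes, device, and tolerances are assumptions).
import torch

M, N = 1024, 1024
kernel = gemv_alloc_reducer(M, N)  # runs the autotuner over get_block_template_configs() and returns a compiled kernel
a = torch.randn(M, N, dtype=torch.float16, device="cuda")
x = torch.randn(N, dtype=torch.float16, device="cuda")
o = kernel(a, x)  # out_idx=[2]: the output tensor o is allocated by the kernel and returned
ref = (a.float() @ x.float()).half()
torch.testing.assert_close(o, ref, rtol=1e-2, atol=1e-2)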


def get_thread_template_configs():
iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(
configs=get_thread_template_configs(),
warmup=3,
rep=20,
)
@jit(
out_idx=[-1],
target="auto",
)
def get_autotuned_kernel(
N,
K,
BLOCK_N=None,
reduce_threads=None,
):
dtype = "float16"
accum_dtype = "float"
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
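    # For float16, TILE_K = 128 // 16 = 8 elements per 128-bit vectorized access; each block reduces BLOCK_K = reduce_threads * TILE_K elements of K per outer iteration.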

@T.prim_func
def main(
A: T.Tensor((K,), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_accum = T.alloc_local((1,), accum_dtype)

T.clear(C_accum)
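            # Each (tn, tk) thread accumulates a partial dot product over its strided slice of the K dimension.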
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.vectorized(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
C_reduced = T.alloc_local((1,), accum_dtype)
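            # Cross-thread allreduce: sum the per-thread partial results over the reduce_threads (tk) dimension into C_reduced.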
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_accum[0],
True,
C_reduced[0],
tk,
dtype="handle",
))

C[bn * BLOCK_N + tn] = C_reduced[0]

return main
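
A rough sketch of how the autotuned thread-template result might be consumed, mirroring the usage in main() further below; the shapes and the torch reference (x @ y.T) come from this file, while the device and tolerances are assumptions.

# Sketch: pick the best split-K config and validate it against the torch reference used by the profiler.
import torch

N, K = 1024, 1024
best = get_autotuned_kernel(N, K)  # autotunes over get_thread_template_configs()
kernel = splitk_gemv_vectorized_tvm(N, K, **best.config)
A = torch.randn(K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")
C = kernel(A, B)  # out_idx=[-1]: C is allocated and returned
torch.testing.assert_close(C, A @ B.T, rtol=1e-2, atol=1e-2)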


@@ -297,7 +344,7 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
print(f"TileLang Latency: {latency} ms\n")


def main():
def main(do_bench: bool = True):
parser = argparse.ArgumentParser(description="GEMV Example")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
@@ -308,16 +355,23 @@ def main():
check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K)
check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K)
check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K)
check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K)

print("Test passed!")

    if do_bench:
best_result = get_autotuned_kernel(N, K)
best_config = best_result.config
kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
profiler = kernel.get_profiler()
latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
print(f"Torch Latency: {latency} ms")
tilelang_thread_latency = profiler.do_bench(kernel, warmup=500)
print(f"TileLang SIMT Latency: {tilelang_thread_latency} ms\n")
kernel = gemv_alloc_reducer(N, K)
profiler = kernel.get_profiler()
tilelang_tile_latency = profiler.do_bench(kernel, warmup=500)
print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n")


if __name__ == "__main__":
examples/gemv/test_example_gemv.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@


def test_example_gemv():
example_gemv.main()
example_gemv.main(do_bench=False)


if __name__ == "__main__":