Commit 514bdea

[Example] Add block level high performance gemv example (#1097)
* add alloc_reducer gemv example
* test
1 parent f003f37 commit 514bdea
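In short, the commit adds a block-level GEMV kernel, gemv_alloc_reducer, alongside the existing per-thread split-K variants: each thread block owns block_M rows of a, streams block_N-wide tiles of a and x through shared memory and fragments, accumulates partial dot products into a buffer obtained from T.alloc_reducer, and finalizes that reducer before copying the block_M results to o. The plain-NumPy helper below is only a reference sketch of that tiling scheme; it is not part of the commit, and the function name is made up for illustration.

import numpy as np

def gemv_block_tiled_reference(a, x, block_M=128, block_N=128):
    """Illustrative NumPy analogue of the tiling in gemv_alloc_reducer (not from the commit)."""
    M, N = a.shape
    o = np.zeros(M, dtype=np.float32)
    for m0 in range(0, M, block_M):                # one "thread block" per block_M output rows
        rows = slice(m0, min(m0 + block_M, M))
        acc = np.zeros(rows.stop - rows.start, np.float32)   # plays the role of the reducer buffer
        for n0 in range(0, N, block_N):            # pipelined sweep over block_N-wide column tiles
            cols = slice(n0, min(n0 + block_N, N))
            # per-tile multiply-accumulate, like the T.Parallel(block_M, block_N) loop
            acc += a[rows, cols].astype(np.float32) @ x[cols].astype(np.float32)
        o[rows] = acc                              # finalize + copy back (T.finalize_reducer / T.copy)
    return o.astype(a.dtype)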

2 files changed (+132, -78 lines)

examples/gemv/example_gemv.py

Lines changed: 131 additions & 77 deletions
@@ -216,75 +216,122 @@ def main(
     return main
 
 
-def get_best_config(N, K):
-
-    def get_configs():
-        iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
-        return [
-            dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())
-        ]
-
-    @autotune(
-        configs=get_configs(),
-        warmup=3,
-        rep=20,
-    )
-    @jit(
-        out_idx=[-1],
-        target="auto",
-    )
-    def kernel(
-        BLOCK_N=None,
-        reduce_threads=None,
+def get_block_template_configs():
+    iter_params = dict(
+        block_M=[2, 4, 8, 32, 64, 128],
+        block_N=[2, 4, 8, 32, 64, 128],
+        num_stages=[0, 1, 2, 3, 4],
+        threads=[32, 64, 128, 256])
+    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
+
+
+@tl.autotune(
+    configs=get_block_template_configs(),
+    warmup=3,
+    rep=20,
+)
+@tl.jit(
+    pass_configs={
+        tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    },
+    out_idx=[2],
+)
+def gemv_alloc_reducer(M,
+                       N,
+                       block_M=128,
+                       block_N=128,
+                       num_stages=2,
+                       threads=256,
+                       dtype: str = "float16",
+                       accum_dtype: str = "float"):
+
+    @T.prim_func
+    def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M,
+                                                                            dtype)):  # type: ignore
+        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
+            o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
+            T.clear(o_reducer)
+            for i0_n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages):
+                a_smem = T.alloc_shared((block_M, block_N), dtype)
+                T.copy(a[i0_m * block_M, i0_n * block_N], a_smem)
+                a_frag = T.alloc_fragment((block_M, block_N), dtype)
+                T.copy(a_smem, a_frag)
+                x_frag = T.alloc_fragment(block_N, dtype)
+                T.copy(x[i0_n * block_N], x_frag)
+                for i1_m, i1_n in T.Parallel(block_M, block_N):
+                    o_reducer[i1_m] += a_frag[i1_m, i1_n] * x_frag[i1_n]
+            T.finalize_reducer(o_reducer)
+            T.copy(o_reducer, o[i0_m * block_M])
+
+    return main
+
+
+def get_thread_template_configs():
+    iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
+    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
+
+
+@autotune(
+    configs=get_thread_template_configs(),
+    warmup=3,
+    rep=20,
+)
+@jit(
+    out_idx=[-1],
+    target="auto",
+)
+def get_autotuned_kernel(
+    N,
+    K,
+    BLOCK_N=None,
+    reduce_threads=None,
+):
+    dtype = "float16"
+    accum_dtype = "float"
+    MAX_TRANSACTION_SIZE_IN_BITS = 128
+    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
+    BLOCK_K = reduce_threads * TILE_K
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((K,), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((N,), dtype),
     ):
-        dtype = "float16"
-        accum_dtype = "float"
-        MAX_TRANSACTION_SIZE_IN_BITS = 128
-        TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
-        BLOCK_K = reduce_threads * TILE_K
-
-        @T.prim_func
-        def main(
-                A: T.Tensor((K,), dtype),
-                B: T.Tensor((N, K), dtype),
-                C: T.Tensor((N,), dtype),
-        ):
-            with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
-                tn = T.get_thread_binding(0)
-                tk = T.get_thread_binding(1)
-                A_local = T.alloc_local((TILE_K,), dtype)
-                B_local = T.alloc_local((TILE_K,), dtype)
-                C_accum = T.alloc_local((1,), accum_dtype)
-
-                T.clear(C_accum)
-                for bk in T.serial(T.ceildiv(K, BLOCK_K)):
-                    for k in T.vectorized(TILE_K):
-                        A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
-                        B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
-                    for k in T.serial(TILE_K):
-                        C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(
-                            accum_dtype)
-                C_reduced = T.alloc_local((1,), accum_dtype)
-                with T.attr(
-                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
-                        "reduce_scope",
-                        T.reinterpret(T.uint64(0), dtype="handle"),
-                ):
-                    T.evaluate(
-                        T.tvm_thread_allreduce(
-                            T.uint32(1),
-                            C_accum[0],
-                            True,
-                            C_reduced[0],
-                            tk,
-                            dtype="handle",
-                        ))
-
-                C[bn * BLOCK_N + tn] = C_reduced[0]
-
-        return main
-
-    return kernel()
+        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
+            tn = T.get_thread_binding(0)
+            tk = T.get_thread_binding(1)
+            A_local = T.alloc_local((TILE_K,), dtype)
+            B_local = T.alloc_local((TILE_K,), dtype)
+            C_accum = T.alloc_local((1,), accum_dtype)
+
+            T.clear(C_accum)
+            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
+                for k in T.vectorized(TILE_K):
+                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
+                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
+                for k in T.serial(TILE_K):
+                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
+            C_reduced = T.alloc_local((1,), accum_dtype)
+            with T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+            ):
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        C_accum[0],
+                        True,
+                        C_reduced[0],
+                        tk,
+                        dtype="handle",
+                    ))
+
+            C[bn * BLOCK_N + tn] = C_reduced[0]
+
+    return main
 
 
 def check_correctness_and_bench(kernel, N, K, bench_ref=True):
@@ -297,7 +344,7 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
     print(f"TileLang Latency: {latency} ms\n")
 
 
-def main():
+def main(do_bench: bool = True):
     parser = argparse.ArgumentParser(description="GEMV Example")
     parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
@@ -308,16 +355,23 @@ def main():
     check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K)
     check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K)
     check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K)
+    check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K)
+
     print("Test passed!")
 
-    best_result = get_best_config(N, K)
-    best_config = best_result.config
-    kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
-    profiler = kernel.get_profiler()
-    latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
-    print(f"Torch Latency: {latency} ms")
-    latency = profiler.do_bench(kernel, warmup=500)
-    print(f"TileLang Latency: {latency} ms\n")
+    if not do_bench:
+        best_result = get_autotuned_kernel(N, K)
+        best_config = best_result.config
+        kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
+        profiler = kernel.get_profiler()
+        latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
+        print(f"Torch Latency: {latency} ms")
+        tilelang_thread_latency = profiler.do_bench(kernel, warmup=500)
+        print(f"TileLang SIMT Latency: {tilelang_thread_latency} ms\n")
+        kernel = gemv_alloc_reducer(N, K)
+        profiler = kernel.get_profiler()
+        tilelang_tile_latency = profiler.do_bench(kernel, warmup=500)
+        print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n")
 
 
 if __name__ == "__main__":
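A minimal usage sketch of the new kernel, mirroring the calls made in main() above. Assumptions not stated in the diff: a CUDA-capable GPU, torch installed, examples/gemv on PYTHONPATH, and the tolerance values below, which are illustrative only.

import torch
from example_gemv import gemv_alloc_reducer

M, K = 1024, 1024
# Passing explicit tile sizes sidesteps the autotuning sweep, as in check_correctness_and_bench.
kernel = gemv_alloc_reducer(M, K, block_M=128, block_N=128)

a = torch.randn(M, K, dtype=torch.float16, device="cuda")
x = torch.randn(K, dtype=torch.float16, device="cuda")
o = kernel(a, x)                        # out_idx=[2]: the output tensor is allocated and returned

ref = (a.float() @ x.float()).half()    # fp32-accumulated reference GEMV
torch.testing.assert_close(o, ref, rtol=1e-2, atol=1e-2)

profiler = kernel.get_profiler()
print("TileLang BlockReduce latency:", profiler.do_bench(kernel, warmup=500), "ms")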

examples/gemv/test_example_gemv.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 
 
 def test_example_gemv():
-    example_gemv.main()
+    example_gemv.main(do_bench=False)
 
 
 if __name__ == "__main__":
