@@ -216,75 +216,122 @@ def main(
     return main


-def get_best_config(N, K):
-
-    def get_configs():
-        iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
-        return [
-            dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())
-        ]
-
-    @autotune(
-        configs=get_configs(),
-        warmup=3,
-        rep=20,
-    )
-    @jit(
-        out_idx=[-1],
-        target="auto",
-    )
-    def kernel(
-        BLOCK_N=None,
-        reduce_threads=None,
+def get_block_template_configs():
+    iter_params = dict(
+        block_M=[2, 4, 8, 32, 64, 128],
+        block_N=[2, 4, 8, 32, 64, 128],
+        num_stages=[0, 1, 2, 3, 4],
+        threads=[32, 64, 128, 256])
+    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
+
+
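+# Autotune the block template: benchmark every config from the grid above
+# (3 warmup runs, 20 timed reps each) and keep the fastest variant.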
+@tl.autotune(
+    configs=get_block_template_configs(),
+    warmup=3,
+    rep=20,
+)
+@tl.jit(
+    pass_configs={
+        tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    },
+    out_idx=[2],
+)
+def gemv_alloc_reducer(M,
+                       N,
+                       block_M=128,
+                       block_N=128,
+                       num_stages=2,
+                       threads=256,
+                       dtype: str = "float16",
+                       accum_dtype: str = "float"):
+
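+    # Block-template GEMV: each thread block computes block_M rows of o = a @ x,
+    # accumulating partial sums in a block-level reducer.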
+    @T.prim_func
+    def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)):  # type: ignore
+        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
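+            # alloc_reducer provides a block-scoped accumulator; replication="all"
+            # keeps per-thread partial copies that finalize_reducer combines later.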
+            o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
+            T.clear(o_reducer)
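+            # Stream block_N-wide tiles of `a` and `x` through shared memory;
+            # num_stages lets the pipeline overlap the copies with compute.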
+            for i0_n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages):
+                a_smem = T.alloc_shared((block_M, block_N), dtype)
+                T.copy(a[i0_m * block_M, i0_n * block_N], a_smem)
+                a_frag = T.alloc_fragment((block_M, block_N), dtype)
+                T.copy(a_smem, a_frag)
+                x_frag = T.alloc_fragment(block_N, dtype)
+                T.copy(x[i0_n * block_N], x_frag)
+                for i1_m, i1_n in T.Parallel(block_M, block_N):
+                    o_reducer[i1_m] += a_frag[i1_m, i1_n] * x_frag[i1_n]
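+            # Collapse the replicated partials into one sum per output row.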
+            T.finalize_reducer(o_reducer)
+            T.copy(o_reducer, o[i0_m * block_M])
+
+    return main
+
+
+def get_thread_template_configs():
+    iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
+    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
+
+
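+# Same autotuning flow for the thread-template (SIMT) kernel: sweep the tile
+# width BLOCK_N and the number of threads cooperating on each reduction.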
+@autotune(
+    configs=get_thread_template_configs(),
+    warmup=3,
+    rep=20,
+)
+@jit(
+    out_idx=[-1],
+    target="auto",
+)
+def get_autotuned_kernel(
+    N,
+    K,
+    BLOCK_N=None,
+    reduce_threads=None,
+):
+    dtype = "float16"
+    accum_dtype = "float"
+    MAX_TRANSACTION_SIZE_IN_BITS = 128
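+    # 128-bit transactions / 16-bit elements = 8 elements per vectorized load.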
+    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
+    BLOCK_K = reduce_threads * TILE_K
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((K,), dtype),
+        B: T.Tensor((N, K), dtype),
+        C: T.Tensor((N,), dtype),
     ):
-        dtype = "float16"
-        accum_dtype = "float"
-        MAX_TRANSACTION_SIZE_IN_BITS = 128
-        TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
-        BLOCK_K = reduce_threads * TILE_K
-
-        @T.prim_func
-        def main(
-            A: T.Tensor((K,), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((N,), dtype),
-        ):
-            with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
-                tn = T.get_thread_binding(0)
-                tk = T.get_thread_binding(1)
-                A_local = T.alloc_local((TILE_K,), dtype)
-                B_local = T.alloc_local((TILE_K,), dtype)
-                C_accum = T.alloc_local((1,), accum_dtype)
-
-                T.clear(C_accum)
-                for bk in T.serial(T.ceildiv(K, BLOCK_K)):
-                    for k in T.vectorized(TILE_K):
-                        A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
-                        B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
-                    for k in T.serial(TILE_K):
-                        C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
-                C_reduced = T.alloc_local((1,), accum_dtype)
-                with T.attr(
-                        T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
-                        "reduce_scope",
-                        T.reinterpret(T.uint64(0), dtype="handle"),
-                ):
-                    T.evaluate(
-                        T.tvm_thread_allreduce(
-                            T.uint32(1),
-                            C_accum[0],
-                            True,
-                            C_reduced[0],
-                            tk,
-                            dtype="handle",
-                        ))
-
-                C[bn * BLOCK_N + tn] = C_reduced[0]
-
-        return main
-
-    return kernel()
+        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
+            tn = T.get_thread_binding(0)
+            tk = T.get_thread_binding(1)
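+            # 2D thread layout: tn selects the output row within the block,
+            # tk indexes the lanes that cooperate on the K reduction.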
+            A_local = T.alloc_local((TILE_K,), dtype)
+            B_local = T.alloc_local((TILE_K,), dtype)
+            C_accum = T.alloc_local((1,), accum_dtype)
+
+            T.clear(C_accum)
+            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
+                for k in T.vectorized(TILE_K):
+                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
+                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
+                for k in T.serial(TILE_K):
+                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
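+            # Sum the per-thread partials across tk with TVM's thread-level
+            # allreduce intrinsic.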
+            C_reduced = T.alloc_local((1,), accum_dtype)
+            with T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+            ):
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        C_accum[0],
+                        True,
+                        C_reduced[0],
+                        tk,
+                        dtype="handle",
+                    ))
+
+            C[bn * BLOCK_N + tn] = C_reduced[0]
+
+    return main


 def check_correctness_and_bench(kernel, N, K, bench_ref=True):
@@ -297,7 +344,7 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
     print(f"TileLang Latency: {latency} ms\n")


-def main():
+def main(do_bench: bool = True):
     parser = argparse.ArgumentParser(description="GEMV Example")
     parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
@@ -308,16 +355,23 @@ def main():
     check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K)
     check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K)
     check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K)
+    check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K)
+
     print("Test passed!")

-    best_result = get_best_config(N, K)
-    best_config = best_result.config
-    kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
-    profiler = kernel.get_profiler()
-    latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
-    print(f"Torch Latency: {latency} ms")
-    latency = profiler.do_bench(kernel, warmup=500)
-    print(f"TileLang Latency: {latency} ms\n")
+    if do_bench:
+        best_result = get_autotuned_kernel(N, K)
+        best_config = best_result.config
+        kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
+        profiler = kernel.get_profiler()
+        latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
+        print(f"Torch Latency: {latency} ms")
+        tilelang_thread_latency = profiler.do_bench(kernel, warmup=500)
+        print(f"TileLang SIMT Latency: {tilelang_thread_latency} ms\n")
+        kernel = gemv_alloc_reducer(N, K)
+        profiler = kernel.get_profiler()
+        tilelang_tile_latency = profiler.do_bench(kernel, warmup=500)
+        print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n")


 if __name__ == "__main__":