@@ -1,4 +1,3 @@
-from asyncio import threads
 from tilelang import tvm as tvm
 import tilelang.testing
 
@@ -90,7 +89,9 @@ def run_gemm_ss(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    profiler = kernel.get_profiler()
+
+    print(kernel.get_kernel_source())
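+    # Assumption: TensorSupplyType.Normal supplies normally distributed random
+    # inputs, a tamer value range for the low-precision (int8/float8) cases below.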
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
     def ref_program(A, B):
         import torch
@@ -109,11 +110,21 @@ def ref_program(A, B):
 def test_gemm_ss():
     # More test cases can be found in kernel/test_tilelang_kernel_gemm.py
     # GEMM tests for float16
-    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
-
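+    # Argument order, inferred from the call sites (not verified against the
+    # run_gemm_ss signature): M, N, K, trans_A, trans_B, A dtype, B dtype,
+    # output dtype, block_M, block_N, block_K, num_stages, optional threads.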
+    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2)
+    # n8 test
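+    # block_N = 8 exercises the narrow-N path; num_stages is 0 here and the
+    # trailing 128 presumably pins the thread count explicitly.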
+    run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 test
+    run_gemm_ss(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
 
 
 def matmul_rs(
@@ -208,7 +219,7 @@ def run_gemm_rs(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
     def ref_program(A, B):
         import torch
@@ -226,8 +237,22 @@ def ref_program(A, B):
 
 def test_gemm_rs():
     # GEMM tests for float16
-    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 0)
-    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 0)
+    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
+
+    # n8 tests
+    run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 tests
+    run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
 
 
 def matmul_sr(
@@ -322,7 +347,7 @@ def run_gemm_sr(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
     def ref_program(A, B):
         import torch
@@ -345,6 +370,18 @@ def test_gemm_sr():
     run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
 
+    # n8 tests
+    run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 tests
+    run_gemm_sr(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
+
 
 def matmul_rr(
     M,
@@ -442,7 +479,7 @@ def run_gemm_rr(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
     def ref_program(A, B):
         import torch
@@ -465,40 +502,20 @@ def test_gemm_rr():
     run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
+    # n8 tests
+    run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2)
+    run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2)
+
+    # int8 tests
+    run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
 
 
 if __name__ == "__main__":
     # tilelang.testing.main()
-    tilelang.disable_cache()
-    # test_gemm_ss()
-    # test_gemm_sr()
-    # test_gemm_rs()
-    # test_gemm_rr()
-
-    # run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
-    # tilelang.testing.set_random_seed(42)
-    run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
-    # print("gemm fp16 nt ss done")
-    # exit()
-
-    # run_gemm_rs(128, 128, 32, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 nt rs done")
-    # run_gemm_rs(128, 128, 32, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 nn rs done")
-    # run_gemm_rs(128, 128, 32, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tn rs done")
-    # run_gemm_rs(128, 128, 32, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tt rs done")
-
-    # run_gemm_rs(16, 16, 16, True, False, "float16", "float16", "float16", 16, 16, 16, 0, 32)
-
-    # run_gemm_rr(128, 128, 32, False, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 nn rr done")
-    # run_gemm_rr(128, 128, 32, False, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 nt rr done")
-    # run_gemm_rr(128, 128, 32, True, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 tn rr done")
-    # run_gemm_rr(128, 128, 32, True, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 tt rr done")
-
-
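+    # A single float8 run, presumably kept as a quick local smoke check; the
+    # full suites live in the test_gemm_* functions above.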
+    run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)