-from asyncio import threads
 from tilelang import tvm as tvm
 import tilelang.testing
 
@@ -90,7 +89,9 @@ def run_gemm_ss(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    profiler = kernel.get_profiler()
+
+    print(kernel.get_kernel_source())
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
     def ref_program(A, B):
         import torch
@@ -109,11 +110,21 @@ def ref_program(A, B):
 def test_gemm_ss():
     # More test cases can be found in kernel/test_tilelang_kernel_gemm.py
     # GEMM tests for float16
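+    # Positional-argument layout, inferred from the call sites below (names are guesses):
+    # run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype,
+    #             block_M, block_N, block_K, num_stages[, threads])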
-    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
-
+    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2)
+    # n8 test
+    run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 tests
+    run_gemm_ss(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
 
 
 def matmul_rs(
@@ -208,7 +219,7 @@ def run_gemm_rs(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
     def ref_program(A, B):
         import torch
@@ -226,8 +237,22 @@ def ref_program(A, B):
 
 def test_gemm_rs():
     # GEMM tests for float16
-    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 0)
-    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 0)
+    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
+
+    # n8 tests
+    run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 tests
+    run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
 
 
 def matmul_sr(
@@ -322,7 +347,7 @@ def run_gemm_sr(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
     def ref_program(A, B):
         import torch
@@ -345,6 +370,18 @@ def test_gemm_sr():
     run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
 
+    # n8 tests
+    run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 tests
+    run_gemm_sr(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
+
 
 def matmul_rr(
     M,
@@ -442,7 +479,7 @@ def run_gemm_rr(
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
         })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
 
     def ref_program(A, B):
         import torch
@@ -465,40 +502,20 @@ def test_gemm_rr():
     run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
+    # n8 tests
+    run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2)
+    run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2)
+
+    # int8 tests
+    run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
 
 
 if __name__ == "__main__":
     # tilelang.testing.main()
-    tilelang.disable_cache()
-    # test_gemm_ss()
-    # test_gemm_sr()
-    # test_gemm_rs()
-    # test_gemm_rr()
-
-    # run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
-    # tilelang.testing.set_random_seed(42)
-    run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
-    # print("gemm fp16 nt ss done")
-    # exit()
-
-    # run_gemm_rs(128, 128, 32, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 nt rs done")
-    # run_gemm_rs(128, 128, 32, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 nn rs done")
-    # run_gemm_rs(128, 128, 32, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tn rs done")
-    # run_gemm_rs(128, 128, 32, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tt rs done")
-
-    # run_gemm_rs(16, 16, 16, True, False, "float16", "float16", "float16", 16, 16, 16, 0, 32)
-
-    # run_gemm_rr(128, 128, 32, False, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 nn rr done")
-    # run_gemm_rr(128, 128, 32, False, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 nt rr done")
-    # run_gemm_rr(128, 128, 32, True, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 tn rr done")
-    # run_gemm_rr(128, 128, 32, True, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 tt rr done")
-
-
+    run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)