55import tilelang .language as T
66from tilelang .utils import map_torch_type
77
8+
89@tl .jit
910def ptr_null_test (M , N , K , block_M , block_N , block_K , dtype = "float16" , accum_dtype = "float" ):
1011
@@ -46,16 +47,17 @@ def main(
4647
4748 return main
4849
50+
4951@tl .jit
5052def tensor_null_test (M , N , K , block_M , block_N , block_K , dtype = "float16" , accum_dtype = "float" ):
5153
5254 @T .prim_func
5355 def main (
54- A : T .Tensor ((M , K ), dtype ),
55- B : T .Tensor ((K , N ), dtype ),
56- C : T .Tensor ((M , N ), accum_dtype ),
57- Bias : T .Tensor ((N ), accum_dtype ),
58- with_bias : T .bool ,
56+ A : T .Tensor ((M , K ), dtype ),
57+ B : T .Tensor ((K , N ), dtype ),
58+ C : T .Tensor ((M , N ), accum_dtype ),
59+ Bias : T .Tensor ((N ), accum_dtype ),
60+ with_bias : T .bool ,
5961 ):
6062 # Initialize Kernel Context
6163 with T .Kernel (T .ceildiv (N , block_N ), T .ceildiv (M , block_M ), threads = 128 ) as (bx , by ):
@@ -79,6 +81,7 @@ def main(
7981
8082 return main
8183
84+
8285def run_test (M , N , K , block_M , block_N , block_K , dtype = "float16" , accum_dtype = "float" ):
8386 func = ptr_null_test (M , N , K , block_M , block_N , block_K , dtype , accum_dtype )
8487
@@ -104,8 +107,10 @@ def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="f
104107 func (a , b , c , d , True )
105108 torch .testing .assert_close (c , ref_with_bias , atol = 1e-2 , rtol = 1e-2 )
106109
110+
107111def test_nullptr ():
108112 run_test (1024 , 1024 , 1024 , 128 , 128 , 32 )
109113
114+
110115if __name__ == "__main__" :
111- tilelang .testing .main ()
116+ tilelang .testing .main ()
0 commit comments