55import tilelang .language as T
66
77
8+ @tilelang .jit
89def vectorize_test (N , M , stride_A , stride_B ):
910 assert N % 128 == 0 and M % 128 == 0
1011
1112 @T .prim_func
1213 def main (
13- A : T .StridedTensor [(N , M ), (1 , stride_A ), "float32" ],
14- B : T .StridedTensor [(N , M ), (1 , stride_B ), "float32" ],
14+ A : T .StridedTensor [(N , M ), (1 , stride_A ), "float32" ], # noqa: F821
15+ B : T .StridedTensor [(N , M ), (1 , stride_B ), "float32" ], # noqa: F821
1516 ):
1617 with T .Kernel (M // 128 , threads = 128 ) as (bx ):
1718 tx = T .get_thread_binding (0 )
@@ -26,8 +27,7 @@ def main(
2627def run_vectorize (N , M , stride_A , stride_B ):
2728 assert stride_A >= N and stride_B >= N
2829
29- program = vectorize_test (N , M , stride_A , stride_B )
30- jit_kernel = tl .compile (program , target = "cuda" , execution_backend = "cython" )
30+ jit_kernel = vectorize_test (N , M , stride_A , stride_B )
3131
3232 base_a = torch .randn (stride_A , M , device = "cuda" , dtype = torch .float32 )
3333 base_b = torch .zeros (stride_B , M , device = "cuda" , dtype = torch .float32 )
0 commit comments