2222
2323
2424@T .prim_func
25- def csrmm (a : T .handle , b : T .handle , c : T .handle , indptr : T .handle , indices : T .handle ) -> None :
26- n = T .var ("int32" )
27- m = T .var ("int32" )
28- k = T .var ("int32" )
29- nnz = T .var ("int32" )
25+ def csrmm (a : T .handle , b : T .handle , c : T .handle , indptr : T .handle , indices : T .handle , m : T .int32 , n : T .int32 , k : T .int32 , nnz : T .int32 ) -> None :
3026 I = T .dense_fixed (m )
3127 J = T .sparse_variable ((n , m + 1 , nnz ), (indptr , indices ), "int32" )
3228 K = T .dense_fixed (k )
@@ -40,26 +36,22 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha
4036
4137
4238@T .prim_func
43- def csrmm_tir (a : T .handle , b : T .handle , c : T .handle , indptr : T .handle , indices : T .handle ) -> None :
39+ def csrmm_tir (a : T .handle , b : T .handle , c : T .handle , indptr : T .handle , indices : T .handle , M : T . int32 , N : T . int32 , K : T . int32 , nnz : T . int32 ) -> None :
4440 T .func_attr ({"global_symbol" : "main" , "tir.noalias" : True })
45- n = T .var ("int32" )
46- m = T .var ("int32" )
47- k = T .var ("int32" )
48- nnz = T .var ("int32" )
4941 A_data = T .match_buffer (a , (nnz ,), "float32" )
50- B = T .match_buffer (b , (n , k ), "float32" )
51- C = T .match_buffer (c , (m , k ), "float32" )
52- A_indptr = T .match_buffer (indptr , (m + 1 ,), "int32" )
42+ B = T .match_buffer (b , (N * K , ), "float32" )
43+ C = T .match_buffer (c , (M * K , ), "float32" )
44+ A_indptr = T .match_buffer (indptr , (M + 1 ,), "int32" )
5345 A_indices = T .match_buffer (indices , (nnz ,), "int32" )
54- for i , k in T .grid (m , k ):
46+ for i , k in T .grid (M , K ):
5547 with T .block ("spmm_outer" ):
5648 vi , vk = T .axis .remap ("SS" , [i , k ])
5749 with T .init ():
58- C [vi , vk ] = 0.
50+ C [vi * K + vk ] = 0.
5951 for j in T .serial (0 , A_indptr [vi + 1 ] - A_indptr [vi ]):
6052 with T .block ("spmm_inner" ):
61- vj = T .axis .R (n , j + A_indptr [vi ])
62- C [vi , vk ] = C [vi , vk ] + A_data [vj ] * B [A_indices [vj ], vk ]
53+ vj = T .axis .R (N , j + A_indptr [vi ])
54+ C [vi * K + vk ] = C [vi * K + vk ] + A_data [vj ] * B [A_indices [vj ] * K + vk ]
6355
6456
6557def test_csrmm ():
@@ -70,7 +62,12 @@ def test_csrmm():
7062 y = np .zeros ((4096 , 256 )).astype ("float32" )
7163
7264 # specialize function
73- sch = tir .Schedule (csrmm_tir )
65+ _ , _ , _ , _ , _ , m , n , k , nnz = csrmm_tir .params
66+ sch = tir .Schedule (
67+ csrmm_tir .specialize (
68+ {m : 4096 , n : 4096 , k : 256 , nnz : A .nnz }
69+ )
70+ )
7471 blk_outer = sch .get_block ("spmm_outer" )
7572 i , k = sch .get_loops (blk_outer )
7673 sch .bind (i , "blockIdx.x" )
@@ -80,14 +77,15 @@ def test_csrmm():
8077 A_indptr = tvm .nd .array (A .indptr .astype ("int32" ), device = tvm .cuda (0 ))
8178 A_indices = tvm .nd .array (A .indices .astype ("int32" ), device = tvm .cuda (0 ))
8279 A_data = tvm .nd .array (A .data .astype ("float32" ), device = tvm .cuda (0 ))
83- X_nd = tvm .nd .array (x , device = tvm .cuda (0 ))
84- Y_nd = tvm .nd .array (y , device = tvm .cuda (0 ))
80+ X_nd = tvm .nd .array (x . reshape ( - 1 ) , device = tvm .cuda (0 ))
81+ Y_nd = tvm .nd .array (y . reshape ( - 1 ) , device = tvm .cuda (0 ))
8582
8683 # build function
8784 f = tvm .build (sch .mod , target = 'cuda' )
8885 f (A_data , X_nd , Y_nd , A_indptr , A_indices )
8986
90- assert np .allclose (y_ground_truth , Y_nd .numpy ())
87+ # assertion
88+ assert np .allclose (y_ground_truth .reshape (- 1 ), Y_nd .numpy ())
9189
9290
9391if __name__ == "__main__" :
0 commit comments