1515# specific language governing permissions and limitations
1616# under the License.
1717import tvm
18+ from tvm .runtime .ndarray import device
1819import tvm .tir as tir
1920import scipy .sparse as sp
2021import numpy as np
@@ -36,49 +37,90 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha
3637
3738
3839@T .prim_func
39- def csrmm_tir (a : T .handle , b : T .handle , c : T .handle , indptr : T .handle , indices : T .handle , M : T .int32 , N : T .int32 , K : T .int32 , nnz : T .int32 ) -> None :
40+ def csrmm_tir (a : T .handle , b : T .handle , c : T .handle , indptr : T .handle , indices : T .handle , M : T .int32 , N : T .int32 , K : T .int32 , NNZ : T .int32 ) -> None :
4041 T .func_attr ({"global_symbol" : "main" , "tir.noalias" : True })
41- A_data = T .match_buffer (a , (nnz ,), "float32" )
42+ A_data = T .match_buffer (a , (NNZ ,), "float32" )
4243 B = T .match_buffer (b , (N * K ,), "float32" )
4344 C = T .match_buffer (c , (M * K ,), "float32" )
4445 A_indptr = T .match_buffer (indptr , (M + 1 ,), "int32" )
45- A_indices = T .match_buffer (indices , (nnz ,), "int32" )
46+ A_indices = T .match_buffer (indices , (NNZ ,), "int32" )
4647 for i , k in T .grid (M , K ):
4748 with T .block ("spmm_outer" ):
4849 vi , vk = T .axis .remap ("SS" , [i , k ])
4950 with T .init ():
5051 C [vi * K + vk ] = 0.
5152 for j in T .serial (0 , A_indptr [vi + 1 ] - A_indptr [vi ]):
5253 with T .block ("spmm_inner" ):
53- vj = T .axis .R (N , j + A_indptr [vi ])
54- C [vi * K + vk ] = C [vi * K + vk ] + A_data [vj ] * B [A_indices [vj ] * K + vk ]
54+ vj = T .axis .R (NNZ , j + A_indptr [vi ])
55+ C [vi * K + vk ] = C [vi * K + vk ] + \
56+ A_data [vj ] * B [A_indices [vj ] * K + vk ]
57+
58+
59+ @T .prim_func
60+ def bsrmm_tir (a : T .handle , b : T .handle , c : T .handle , indptr : T .handle , indices : T .handle , MB : T .int32 , NB : T .int32 , K : T .int32 , BLOCK_SIZE : T .int32 , NNZB : T .int32 ) -> None :
61+ T .func_attr ({"global_symbol" : "main" , "tir.noalias" : True })
62+ A_data = T .match_buffer (a , (NNZB * BLOCK_SIZE * BLOCK_SIZE ), "float32" )
63+ B = T .match_buffer (b , (NB * BLOCK_SIZE * K ,), "float32" )
64+ C = T .match_buffer (c , (MB * BLOCK_SIZE * K ,), "float32" )
65+ A_indptr = T .match_buffer (indptr , (MB + 1 ,), "int32" )
66+ A_indices = T .match_buffer (indices , (NNZB ,), "int32" )
67+ for io , ii , ji , k in T .grid (MB , BLOCK_SIZE , BLOCK_SIZE , K ):
68+ with T .block ("spmm_outer" ):
69+ vio , vii , vji , vk = T .axis .remap ("SSSS" , [io , ii , ji , k ])
70+ with T .init ():
71+ C [(vio * BLOCK_SIZE + vii ) * K + vk ] = 0.
72+ for jo in T .serial (0 , A_indptr [vio + 1 ] - A_indptr [vio ]):
73+ with T .block ("spmm_inner" ):
74+ vjo = T .axis .R (NNZB , jo + A_indptr [vio ])
75+ C [(vio * BLOCK_SIZE + vii ) * K + vk ] = C [(vio * BLOCK_SIZE + vii ) * K + vk ] + A_data [(
76+ vjo * BLOCK_SIZE + vii ) * BLOCK_SIZE + vji ] * B [(A_indices [vjo ] * BLOCK_SIZE + vji ) * K + vk ]
77+
78+
79+ @T .prim_func
80+ def ellmm_tir (a : T .handle , b : T .handle , c : T .handle , indices : T .handle , M : T .int32 , N : T .int32 , K : T .int32 , NNZ_COLS : T .int32 ) -> None :
81+ T .func_attr ({"global_symbol" : "main" , "tir.noalias" : True })
82+ A_data = T .match_buffer (a , (M * NNZ_COLS ,), "float32" )
83+ B = T .match_buffer (b , (N * K ,), "float32" )
84+ C = T .match_buffer (c , (M * K ,), "float32" )
85+ A_indices = T .match_buffer (indices , (M * NNZ_COLS ,), "int32" )
86+ for i , j , k in T .grid (M , NNZ_COLS , K ):
87+ with T .block ("spmm" ):
88+ vi , vj , vk = T .axis .remap ("SRS" , [i , j , k ])
89+ with T .init ():
90+ C [vi * K + vk ] = 0.
91+ C [vi * K + vk ] = C [vi * K + vk ] + A_data [vi * NNZ_COLS + vj ] * \
92+ B [A_indices [vi * NNZ_COLS + vj ] * K + vk ]
5593
5694
5795def test_csrmm ():
5896 # generate random input
59- A = sp .random (4096 , 4096 , dtype = "float32" , density = 0.0125 , format = 'csr' )
60- x = np .random .rand (4096 , 256 ).astype ("float32" )
97+ m = 4096
98+ n = 4096
99+ k = 256
100+ A = sp .random (m , n , dtype = "float32" , density = 0.0125 , format = 'csr' )
101+ nnz = A .nnz
102+ x = np .random .rand (n , k ).astype ("float32" )
61103 y_ground_truth = A * x
62- y = np .zeros ((4096 , 256 )).astype ("float32" )
104+ y = np .zeros ((m * k , )).astype ("float32" )
63105
64106 # specialize function
65- _ , _ , _ , _ , _ , m , n , k , nnz = csrmm_tir .params
107+ _ , _ , _ , _ , _ , M , N , K , NNZ = csrmm_tir .params
66108 sch = tir .Schedule (
67109 csrmm_tir .specialize (
68- {m : 4096 , n : 4096 , k : 256 , nnz : A . nnz }
110+ {M : m , N : n , K : k , NNZ : nnz }
69111 )
70112 )
71113 blk_outer = sch .get_block ("spmm_outer" )
72114 i , k = sch .get_loops (blk_outer )
73115 sch .bind (i , "blockIdx.x" )
74116 sch .bind (k , "threadIdx.x" )
75-
117+
76118 # convert numpy tensor to tvm ndarray
77119 A_indptr = tvm .nd .array (A .indptr .astype ("int32" ), device = tvm .cuda (0 ))
78120 A_indices = tvm .nd .array (A .indices .astype ("int32" ), device = tvm .cuda (0 ))
79121 A_data = tvm .nd .array (A .data .astype ("float32" ), device = tvm .cuda (0 ))
80122 X_nd = tvm .nd .array (x .reshape (- 1 ), device = tvm .cuda (0 ))
81- Y_nd = tvm .nd .array (y . reshape ( - 1 ) , device = tvm .cuda (0 ))
123+ Y_nd = tvm .nd .array (y , device = tvm .cuda (0 ))
82124
83125 # build function
84126 f = tvm .build (sch .mod , target = 'cuda' )
@@ -88,5 +130,102 @@ def test_csrmm():
88130 assert np .allclose (y_ground_truth .reshape (- 1 ), Y_nd .numpy ())
89131
90132
133+ def test_bsrmm ():
134+ # generate random input
135+ block_size = 1
136+ mb = 64
137+ nb = 64
138+ k = 256
139+ m = mb * block_size
140+ n = nb * block_size
141+ A_block = sp .random (mb , nb , dtype = "float32" , density = 0.05 , format = 'csr' )
142+ indptr = A_block .indptr
143+ indices = A_block .indices
144+ nnzb = A_block .nnz
145+ data = np .random .rand (nnzb , block_size , block_size )
146+ A = sp .bsr_matrix ((data , indices , indptr ), shape = (m , n ))
147+ x = np .random .rand (n , k ).astype ("float32" )
148+ y_ground_truth = A * x
149+ y = np .zeros ((m * k ,)).astype ("float32" )
150+
151+ # specialize function
152+ _ , _ , _ , _ , _ , MB , NB , K , BLOCK_SIZE , NNZB = bsrmm_tir .params
153+ sch = tir .Schedule (
154+ bsrmm_tir .specialize (
155+ {MB : mb , NB : nb , K : k , BLOCK_SIZE : block_size , NNZB : nnzb }
156+ )
157+ )
158+ blk_outer = sch .get_block ("spmm_outer" )
159+ io , ii , ji , k = sch .get_loops (blk_outer )
160+ sch .unroll (ii )
161+ sch .unroll (ji )
162+ sch .bind (io , "blockIdx.x" )
163+ sch .bind (k , "threadIdx.x" )
164+
165+ # convert numpy tensor to tvm ndarray
166+ A_indptr = tvm .nd .array (indptr .astype ("int32" ), device = tvm .cuda (0 ))
167+ A_indices = tvm .nd .array (indices .astype ("int32" ), device = tvm .cuda (0 ))
168+ A_data = tvm .nd .array (
169+ data .reshape (- 1 ).astype ("float32" ), device = tvm .cuda (0 ))
170+ X_nd = tvm .nd .array (x .reshape (- 1 ), device = tvm .cuda (0 ))
171+ Y_nd = tvm .nd .array (y , device = tvm .cuda (0 ))
172+
173+ # build function
174+ f = tvm .build (sch .mod , target = "cuda" )
175+ f (A_data , X_nd , Y_nd , A_indptr , A_indices )
176+
177+ # assertion
178+ assert np .allclose (y_ground_truth .reshape (- 1 ), Y_nd .numpy ())
179+
180+
181+ def test_ellmm ():
182+ # generate random input
183+ nnz_cols = 64
184+ m = 4096
185+ n = 4096
186+ k = 256
187+ nnz = nnz_cols * m
188+ indptr = np .arange (0 , (m + 1 ) * nnz_cols , nnz_cols )
189+ indices = np .random .randint (0 , n , size = (nnz ,))
190+ data = np .random .rand (nnz )
191+ A = sp .csr_matrix ((data , indices , indptr ), shape = (m , n ))
192+ x = np .random .rand (n , k ).astype ("float32" )
193+ y_ground_truth = A * x
194+ y = np .zeros ((m * k ,)).astype ("float32" )
195+ # specialize function
196+ _ , _ , _ , _ , M , N , K , NNZ_COLS = ellmm_tir .params
197+ sch = tir .Schedule (
198+ ellmm_tir .specialize (
199+ {M : m , N : n , K : k , NNZ_COLS : nnz_cols }
200+ )
201+ )
202+ blk = sch .get_block ("spmm" )
203+ i , j , k = sch .get_loops (blk )
204+ sch .bind (i , "blockIdx.x" )
205+ sch .bind (k , "threadIdx.x" )
206+ sch .unroll (j )
207+
208+ # convert numpy tensor to tvm ndarray
209+ A_indices = tvm .nd .array (indices .astype ("int32" ), device = tvm .cuda (0 ))
210+ A_data = tvm .nd .array (data .astype ("float32" ), device = tvm .cuda (0 ))
211+ X_nd = tvm .nd .array (x .reshape (- 1 ), device = tvm .cuda (0 ))
212+ Y_nd = tvm .nd .array (y , device = tvm .cuda (0 ))
213+
214+ # build function
215+ f = tvm .build (sch .mod , target = "cuda" )
216+ f (A_data , X_nd , Y_nd , A_indices )
217+
218+ # assertion
219+ assert np .allclose (y_ground_truth .reshape (- 1 ), Y_nd .numpy ())
220+
221+
222+ def test_bmm ():
223+ # TODO(zihao)
224+ pass
225+
226+
91227if __name__ == "__main__" :
92- test_csrmm ()
228+ test_csrmm ()
229+ test_bsrmm ()
230+ test_ellmm ()
231+ test_bmm ()
0 commit comments