@@ -92,6 +92,25 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
9292 B [A_indices [vi * NNZ_COLS + vj ] * K + vk ]
9393
9494
95+ @T .prim_func
96+ def sddmm_tir (a : T .handle , b : T .handle , c : T .handle , indptr : T .handle , indices : T .handle , M : T .int32 , N : T .int32 , K : T .int32 , NNZ : T .int32 ) -> None :
97+ T .func_attr ({"global_symbol" : "main" , "tir.noalis" : True })
98+ A = T .match_buffer (a , (M * K ,), "float32" )
99+ B = T .match_buffer (b , (N * K ,), "float32" )
100+ C_data = T .match_buffer (c , (NNZ ,), "float32" )
101+ C_indptr = T .match_buffer (indptr , (M + 1 ,), "int32" )
102+ C_indices = T .match_buffer (indices , (NNZ ,), "int32" )
103+ for ij , k in T .grid (NNZ , K ):
104+ with T .block ("sddmm" ):
105+ vij , vk = T .axis .remap ("SR" , [ij , k ])
106+ T .reads ([A [0 : M * K ], B [0 : N * K ], C_data [vij ], C_indices [vij ], C_indptr [0 : M + 1 ]])
107+ T .writes ([C_data [vij ]])
108+ with T .init ():
109+ C_data [vij ] = 0.
110+ C_data [vij ] = C_data [vij ] + \
111+ A [T .lower_bound (C_indptr .data , vij , 0 , M + 1 ) * K + vk ] * B [C_indices [vij ] * K + vk ]
112+
113+
95114def test_csrmm ():
96115 # generate random input
97116 m = 4096
@@ -219,6 +238,50 @@ def test_ellmm():
219238 assert np .allclose (y_ground_truth .reshape (- 1 ), Y_nd .numpy ())
220239
221240
241+ def test_sddmm ():
242+ # generate random input
243+ m = 4096
244+ n = 4096
245+ k = 256
246+ C = sp .random (m , n , dtype = "float32" , density = 0.0125 , format = 'csr' )
247+ indptr = C .indptr
248+ indices = C .indices
249+ C_coo = C .tocoo ()
250+ nnz = C .nnz
251+ x = np .random .rand (m , k ).astype ("float32" )
252+ y = np .random .rand (n , k ).astype ("float32" )
253+ z_ground_truth = np .matmul (x , y .transpose ())[C_coo .row , C_coo .col ]
254+ z = np .zeros ((nnz ,)).astype ("float32" )
255+
256+ # specialize function
257+ _ , _ , _ , _ , _ , M , N , K , NNZ = sddmm_tir .params
258+ sch = tir .Schedule (
259+ sddmm_tir .specialize (
260+ {M : m , N : n , K : k , NNZ : nnz }
261+ )
262+ )
263+ blk = sch .get_block ("sddmm" )
264+ ij , k = sch .get_loops (blk )
265+ #sch.decompose_reduction(blk, ij)
266+ sch .bind (ij , "blockIdx.x" )
267+ ko , ki = sch .split (k , [None , 1 ])
268+ sch .bind (ki , "threadIdx.x" )
269+
270+ # convert numpy tensor to tvm ndarray
271+ C_indices = tvm .nd .array (indices .astype ("int32" ), device = tvm .cuda (0 ))
272+ C_indptr = tvm .nd .array (indptr .astype ("int32" ), device = tvm .cuda (0 ))
273+ X_nd = tvm .nd .array (x .reshape (- 1 ), device = tvm .cuda (0 ))
274+ Y_nd = tvm .nd .array (y .reshape (- 1 ), device = tvm .cuda (0 ))
275+ C_data = tvm .nd .array (z , device = tvm .cuda (0 ))
276+
277+ # build function
278+ f = tvm .build (sch .mod ['main' ], target = "cuda" )
279+ f (X_nd , Y_nd , C_data , C_indptr , C_indices )
280+
281+ # assertion
282+ np .allclose (z_ground_truth , C_data .numpy ())
283+
284+
222285def test_bmm ():
223286 # TODO(zihao)
224287 pass
@@ -228,4 +291,5 @@ def test_bmm():
228291 test_csrmm ()
229292 test_bsrmm ()
230293 test_ellmm ()
294+ test_sddmm ()
231295 test_bmm ()
0 commit comments