# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.tir as tir
import scipy.sparse as sp
import numpy as np
from tvm.script import tir as T


# Sparse-TIR description of CSR SpMM: C[m, k] = A[m, n] (CSR) * B[n, k].
@T.prim_func
def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None:
    n = T.var("int32")
    m = T.var("int32")
    k = T.var("int32")
    nnz = T.var("int32")
    I = T.dense_fixed(m)
    J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32")
    K = T.dense_fixed(k)
    A = T.match_sparse_buffer(a, (I, J), nnz, "float32")
    B = T.match_sparse_buffer(b, (T.to_dense(J), K), n * k, "float32")
    C = T.match_sparse_buffer(c, (I, K), m * k, "float32")
    # i and k are spatial axes, j is the reduction axis ("SRS").
    with T.iter([T.cord(I), T.cord(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]:
        with T.init():
            C[vi, vk] = 0.0
        C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]


# Plain TIR equivalent of the CSR SpMM above: the outer spatial loops iterate over
# rows of C and columns of B, and the inner reduction loop walks the nonzeros of
# row vi through the indptr/indices arrays.
@T.prim_func
def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    n = T.var("int32")
    m = T.var("int32")
    k = T.var("int32")
    nnz = T.var("int32")
    A_data = T.match_buffer(a, (nnz,), "float32")
    B = T.match_buffer(b, (n, k), "float32")
    C = T.match_buffer(c, (m, k), "float32")
    A_indptr = T.match_buffer(indptr, (m + 1,), "int32")
    A_indices = T.match_buffer(indices, (nnz,), "int32")
    for i, k in T.grid(m, k):
        with T.block("spmm_outer"):
            vi, vk = T.axis.remap("SS", [i, k])
            with T.init():
                C[vi, vk] = 0.0
            for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
                with T.block("spmm_inner"):
                    # vj indexes into the nonzero arrays, so its extent is nnz.
                    vj = T.axis.R(nnz, j + A_indptr[vi])
                    C[vi, vk] = C[vi, vk] + A_data[vj] * B[A_indices[vj], vk]


def test_csrmm():
    # generate random input
    A = sp.random(4096, 4096, dtype="float32", density=0.0125, format="csr")
    x = np.random.rand(4096, 256).astype("float32")
    y_ground_truth = A * x
    y = np.zeros((4096, 256)).astype("float32")

    # schedule the TIR function: bind the row loop to blockIdx.x and the column loop to threadIdx.x
    sch = tir.Schedule(csrmm_tir)
    blk_outer = sch.get_block("spmm_outer")
    i, k = sch.get_loops(blk_outer)
    sch.bind(i, "blockIdx.x")
    sch.bind(k, "threadIdx.x")

    # convert numpy tensors to tvm ndarrays on the GPU
    A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=tvm.cuda(0))
    A_indices = tvm.nd.array(A.indices.astype("int32"), device=tvm.cuda(0))
    A_data = tvm.nd.array(A.data.astype("float32"), device=tvm.cuda(0))
    X_nd = tvm.nd.array(x, device=tvm.cuda(0))
    Y_nd = tvm.nd.array(y, device=tvm.cuda(0))

    # build and run the function
    f = tvm.build(sch.mod, target="cuda")
    f(A_data, X_nd, Y_nd, A_indptr, A_indices)

    # check the result against the scipy ground truth
    assert np.allclose(y_ground_truth, Y_nd.numpy())


if __name__ == "__main__":
    test_csrmm()