23
23
from ..utils import get_const_tuple
24
24
25
25
26
- def sparse_dense (data , weight_data , weight_indices , weight_indptr ):
26
+ def sparse_dense_v2 (data , weight_data , weight_indices , weight_indptr ):
27
27
"""
28
28
Computes sparse-dense matrix multiplication of `data` and
29
29
`(weight_data, weight_indices, weight_indptr).T`
@@ -52,13 +52,104 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
52
52
"""
53
53
assert len (weight_data .shape ) in (1 , 3 )
54
54
if len (weight_data .shape ) == 1 :
55
- func = _sparse_dense_csrmm
55
+ func = _sparse_dense_csrmm_v2
56
56
if len (weight_data .shape ) == 3 :
57
- func = _sparse_dense_bsrmm
57
+ func = _sparse_dense_bsrmm_v2
58
58
return func (data , weight_data , weight_indices , weight_indptr )
59
59
60
60
61
- def _sparse_dense_csrmm (data , weight_data , weight_indices , weight_indptr ):
61
+ def sparse_dense_v1 (data_data , data_indices , data_indptr , weight ):
62
+ """
63
+ Computes sparse-dense matrix multiplication of
64
+ `(data_data, data_indices, data_indptr)` and `weight.T`
65
+
66
+ Parameters
67
+ ----------
68
+ data_data:
69
+ 1-D with shape [nnz] (CSR) or
70
+ 3-D with shape [num_blocks, bs_r, bs_c] (BSR)
71
+
72
+ data_indices:
73
+ 1-D with shape [nnz] (CSR) or
74
+ 1-D with shape [num_blocks] (BSR)
75
+
76
+ data_indptr:
77
+ 1-D with shape [M + 1] (CSR) or
78
+ 1-D with shape [(M + 1) // bs_r] (BSR)
79
+
80
+ weight:
81
+ 2-D with shape [N, K], float32
82
+
83
+ Returns
84
+ -------
85
+ output : tvm.te.Tensor
86
+ 2-D with shape [M, N]
87
+ """
88
+ assert len (data_data .shape ) in (1 , 3 )
89
+ if len (data_data .shape ) == 1 :
90
+ func = _sparse_dense_csrmm_v1
91
+ if len (data_data .shape ) == 3 :
92
+ func = _sparse_dense_bsrmm_v1
93
+ return func (data_data , data_indices , data_indptr , weight )
94
+
95
+
96
+ # pylint: disable=no-else-return,inconsistent-return-statements
97
+ def sparse_dense (dense_data , sparse_data , sparse_indices , sparse_indptr , sparse_lhs = False ):
98
+ """
99
+ Computes sparse-dense matrix multiplication of `data` and
100
+ `(weight_data, weight_indices, weight_indptr).T`, if sparse_lhs=False
101
+ or
102
+ Computes sparse-dense matrix multiplication of
103
+ `(data_data, data_indices, data_indptr)` and `weight.T`, if sparse_lhs=True
104
+
105
+ Parameters
106
+ ----------
107
+ dense_data : tvm.te.Tensor
108
+ 2-D with shape [M, K], float32
109
+
110
+ sparse_data : tvm.te.Tensor
111
+ 1-D with shape [nnz] (CSR) or
112
+ 3-D with shape [num_blocks, bs_r, bs_c] (BSR)
113
+
114
+ sparse_indices : tvm.te.Tensor
115
+ 1-D with shape [nnz] (CSR) or
116
+ 1-D with shape [num_blocks] (BSR)
117
+
118
+ sparse_indptr : tvm.te.Tensor
119
+ 1-D with shape [N + 1] (CSR) or
120
+ 1-D with shape [(N + 1) // bs_r] (BSR)
121
+
122
+ sparse_lhs : bool, optional
123
+ Indicates whether lhs or rhs matrix is sparse. Default value is False.
124
+
125
+ Returns
126
+ -------
127
+ output : tvm.te.Tensor
128
+ 2-D with shape [M, N]
129
+ """
130
+ if sparse_lhs :
131
+ return sparse_dense_v1 (sparse_data , sparse_indices , sparse_indptr , dense_data )
132
+ else :
133
+ return sparse_dense_v2 (dense_data , sparse_data , sparse_indices , sparse_indptr )
134
+
135
+
136
+ def _sparse_dense_csrmm_v1 (data_data , data_indices , data_indptr , weight ):
137
+ oshape = (get_const_tuple (data_indptr .shape )[0 ] - 1 , get_const_tuple (weight .shape )[0 ])
138
+
139
+ def f (row , i ):
140
+ row_start = data_indptr [row ]
141
+ row_end = data_indptr [row + 1 ]
142
+ row_elems = row_end - row_start
143
+ elem_idx = te .reduce_axis ((0 , row_elems ), name = "elem_idx" )
144
+ elem = row_start + elem_idx
145
+ a_val = data_data [elem ]
146
+ weight_val = weight [i , data_indices [elem ]]
147
+ return te .sum (a_val * weight_val , axis = elem_idx )
148
+
149
+ return te .compute (oshape , f , tag = "sparse_dense_csrmm_v1" )
150
+
151
+
152
+ def _sparse_dense_csrmm_v2 (data , weight_data , weight_indices , weight_indptr ):
62
153
oshape = (get_const_tuple (data .shape )[0 ], get_const_tuple (weight_indptr .shape )[0 ] - 1 )
63
154
64
155
def f (i , row ):
@@ -71,10 +162,41 @@ def f(i, row):
71
162
weight_val = data [i , weight_indices [elem ]]
72
163
return te .sum (a_val * weight_val , axis = elem_idx )
73
164
74
- return te .compute (oshape , f , tag = "sparse_dense_csrmm " )
165
+ return te .compute (oshape , f , tag = "sparse_dense_csrmm_v2 " )
75
166
76
167
77
- def _sparse_dense_bsrmm (data , weight_data , weight_indices , weight_indptr ):
168
+ def _sparse_dense_bsrmm_v1 (data_data , data_indices , data_indptr , weight ):
169
+ (m , _ ) = get_const_tuple (weight .shape )
170
+ (_ , bs_r , bs_c ) = get_const_tuple (data_data .shape )
171
+ (num_blocks_plus_1 ,) = get_const_tuple (data_indptr .shape )
172
+ num_blocks = num_blocks_plus_1 - 1
173
+
174
+ def _compute_block (nb_j , j , i ):
175
+ row_start = data_indptr [nb_j ]
176
+ row_end = data_indptr [nb_j + 1 ]
177
+ row_elems = row_end - row_start
178
+ elem_idx = te .reduce_axis ((0 , row_elems ), name = "elem_idx" )
179
+ block_offset = row_start + elem_idx
180
+ c = te .reduce_axis ((0 , bs_c ), name = "c" )
181
+ block_j = data_indices [block_offset ]
182
+ block_ij_val = data_data [block_offset ][j ][c ]
183
+ x_val = weight [i , bs_c * block_j + c ]
184
+ return te .sum (block_ij_val * x_val , axis = [elem_idx , c ])
185
+
186
+ idxd = tvm .tir .indexdiv
187
+ idxm = tvm .tir .indexmod
188
+
189
+ bsrmm_block = te .compute (
190
+ (num_blocks , bs_r , m ), _compute_block , tag = "sparse_dense_bsrmm_block_v1"
191
+ )
192
+ return te .compute (
193
+ (num_blocks * bs_r , m ),
194
+ lambda m , n : bsrmm_block [idxd (m , bs_r ), idxm (m , bs_r ), n ],
195
+ tag = "sparse_dense_bsrmm_v1" ,
196
+ )
197
+
198
+
199
+ def _sparse_dense_bsrmm_v2 (data , weight_data , weight_indices , weight_indptr ):
78
200
(m , _ ) = get_const_tuple (data .shape )
79
201
(_ , bs_r , bs_c ) = get_const_tuple (weight_data .shape )
80
202
(num_blocks_plus_1 ,) = get_const_tuple (weight_indptr .shape )
@@ -95,11 +217,13 @@ def _compute_block(i, nb_j, j):
95
217
idxd = tvm .tir .indexdiv
96
218
idxm = tvm .tir .indexmod
97
219
98
- bsrmm_block = te .compute ((m , num_blocks , bs_r ), _compute_block , tag = "sparse_dense_bsrmm_block" )
220
+ bsrmm_block = te .compute (
221
+ (m , num_blocks , bs_r ), _compute_block , tag = "sparse_dense_bsrmm_block_v2"
222
+ )
99
223
return te .compute (
100
224
(m , num_blocks * bs_r ),
101
225
lambda m , n : bsrmm_block [m , idxd (n , bs_r ), idxm (n , bs_r )],
102
- tag = "sparse_dense_bsrmm " ,
226
+ tag = "sparse_dense_bsrmm_v2 " ,
103
227
)
104
228
105
229
0 commit comments