Skip to content

Commit 862655b

Browse files
author
ANSHUMAN TRIPATHY
authored
[TOPI] sparse_dense Op sparse_data input added (#6889)
* [TOPI] sparse_dense op sparse_data input added * [1] clang issue resolved * [2] python format resolved * [3] lint error resolved * [4] Review comments handled * [5] Lint error resolved * [6] Review comments handled * [7] Review comments handled * [8] Review comments handled
1 parent 054466b commit 862655b

File tree

11 files changed

+321
-64
lines changed

11 files changed

+321
-64
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,15 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
938938

939939
/*! \brief Attributes for sparse_dense operator */
940940
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
941-
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
941+
bool sparse_lhs;
942+
943+
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {
944+
TVM_ATTR_FIELD(sparse_lhs)
945+
.set_default(false)
946+
.describe(
947+
"Indicate whether sparse matrix is multiplied on the right or the left. If true, then "
948+
"the operation is S * D^T (D dense, S sparse). If false, the operation is D * S^T");
949+
}
942950
};
943951

944952
/*! \brief Attributes for sparse_transpose operator */

python/tvm/relay/frontend/tensorflow.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -926,28 +926,45 @@ def _impl(inputs, attr, params, mod):
926926

927927
data = inputs[3]
928928

929+
# By default, in TensorFlow the first input (i.e., the data) is sparse
930+
sparse_lhs = True
931+
932+
# If both are true, it means the first input was dense and the second was sparse
933+
if attr.get("adjoint_a") and attr.get("adjoint_b"):
934+
sparse_lhs = False
935+
929936
rows = [x[0] for x in indices_tensor]
930937
cols = [x[1] for x in indices_tensor]
931938

932939
# Create scipy sparse Tensor(CSR)
933940
weight_sp = csr_matrix(
934941
(values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
935942
)
936-
weight_sp = csr_matrix(weight_sp.transpose())
943+
944+
if sparse_lhs:
945+
data = _op.transpose(data)
946+
else:
947+
weight_sp = csr_matrix(weight_sp.transpose())
937948

938949
weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
939950
weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
940951
weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)
941952

942-
ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs])
953+
ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs], sparse_lhs)
943954

944-
# If both are true means First input was dense and second was sparse
945-
# TODO(ANSHUMAN87): Support other adjoint option too
946-
if attr.get("adjoint_a") and attr.get("adjoint_b"):
955+
if not sparse_lhs:
947956
ret = _op.transpose(ret)
948-
else:
957+
958+
# Case 1. If both are true, the first input was dense and the second was sparse
959+
# Case 2. If both are false, the first input was sparse and the second was dense
960+
# TODO(ANSHUMAN87): Support other adjoint option too
961+
if not (
962+
(attr.get("adjoint_a") and attr.get("adjoint_b"))
963+
or ((not attr.get("adjoint_a")) and (not attr.get("adjoint_b")))
964+
):
949965
raise tvm.error.OpAttributeUnImplemented(
950966
"Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True"
967+
"or with adjoint_a=False and adjoint_b=False"
951968
" is supported, but adjoint_a={} and adjoint_b={} was supplied.".format(
952969
attr.get("adjoint_a"), attr.get("adjoint_b")
953970
)

python/tvm/relay/op/nn/_nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def compute_fifo_buffer(attrs, inputs, out_type):
6969
@reg.register_compute("nn.sparse_dense")
7070
def compute_sparse_dense(attrs, inputs, out_type):
7171
"""Compute definition of sparse_dense"""
72-
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
72+
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])]
7373

7474

7575
reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy)

python/tvm/relay/op/nn/nn.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1993,17 +1993,27 @@ def batch_matmul(x, y):
19931993
return _make.batch_matmul(x, y)
19941994

19951995

1996-
def sparse_dense(data, weight):
1996+
# pylint: disable=no-else-return,inconsistent-return-statements
1997+
def sparse_dense(dense_mat, sparse_mat, sparse_lhs=False):
19971998
r"""
1998-
Computes the matrix multiplication of `data` and `weight`, where `data` is
1999-
a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
1999+
Computes the matrix multiplication of `dense_mat` and `sparse_mat`, where `dense_mat` is
2000+
a dense matrix and `sparse_mat` is a sparse (either BSR or CSR) namedtuple with
20002001
fields `data`, `indices`, and `indptr`.
20012002
2002-
.. math::
2003+
\if sparse_lhs=False:
2004+
.. math::
2005+
2006+
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
2007+
= \mbox{matmul}(D, \mbox{as_dense}(S)^T)[m, n]
20032008
2004-
\mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(x, \mbox{as_dense}(weight)^T)[m, n]
2009+
\if sparse_lhs=True:
2010+
.. math::
20052011
2006-
where `as_dense` returns dense equivalent of the given sparse matrix.
2012+
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
2013+
= \mbox{matmul}(\mbox{as_dense}(S), (D)^T)[m, n]
2014+
2015+
where `as_dense` returns dense equivalent of the given S(sparse matrix)
2016+
while performing matmul with given D(dense matrix).
20072017
20082018
See
20092019
https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
@@ -2013,20 +2023,28 @@ def sparse_dense(data, weight):
20132023
20142024
Parameters
20152025
----------
2016-
data : tvm.relay.Expr
2017-
The input data for the matrix multiplication
2026+
dense_mat : tvm.relay.Expr
2027+
The input dense matrix for the matrix multiplication
20182028
2019-
weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
2020-
The sparse weight matrix for the matrix multiplication.
2029+
sparse_mat : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
2030+
The input sparse matrix for the matrix multiplication.
2031+
2032+
sparse_lhs : bool, optional
2033+
Indicates whether lhs or rhs matrix is sparse. Default value is False.
20212034
20222035
Returns
20232036
-------
20242037
result: tvm.relay.Expr
20252038
The computed result.
20262039
"""
2027-
if hasattr(weight, "indices"):
2028-
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
2029-
return _make.sparse_dense(data, weight[0], weight[1], weight[2])
2040+
if hasattr(sparse_mat, "indices"):
2041+
return _make.sparse_dense(
2042+
dense_mat, sparse_mat.data, sparse_mat.indices, sparse_mat.indptr, sparse_lhs
2043+
)
2044+
else:
2045+
return _make.sparse_dense(
2046+
dense_mat, sparse_mat[0], sparse_mat[1], sparse_mat[2], sparse_lhs
2047+
)
20302048

20312049

20322050
def sparse_transpose(x):

python/tvm/relay/op/strategy/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ def wrap_compute_sparse_dense(topi_compute):
737737
"""wrap sparse dense topi compute"""
738738

739739
def _compute_sparse_dense(attrs, inputs, out_type):
740-
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
740+
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])]
741741

742742
return _compute_sparse_dense
743743

python/tvm/topi/cuda/sparse.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,11 @@ def schedule_sparse_dense(outs):
6565
# pylint:disable=invalid-name
6666
s = te.create_schedule([x.op for x in outs])
6767

68+
# TODO(ANSHUMAN87): Add for sparse_dense_bsrmm_v1 also
6869
def _callback(op):
69-
if op.tag == "sparse_dense_bsrmm":
70+
if op.tag == "sparse_dense_bsrmm_v2":
7071
y_bsrmm = op.input_tensors[0]
71-
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
72+
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2"
7273
out = s.outputs[0].output(0)
7374

7475
if op not in s.outputs:
@@ -362,6 +363,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
362363
sparse_dense implementation for one that operates on a padded matrix. We
363364
also pad the matrix.
364365
"""
366+
# TODO(ANSHUMAN87): Handle for sparse_lhs case too
365367
if (
366368
isinstance(inputs[1], relay.Constant)
367369
and isinstance(inputs[2], relay.Constant)

python/tvm/topi/nn/sparse.py

Lines changed: 132 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ..utils import get_const_tuple
2424

2525

26-
def sparse_dense(data, weight_data, weight_indices, weight_indptr):
26+
def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr):
2727
"""
2828
Computes sparse-dense matrix multiplication of `data` and
2929
`(weight_data, weight_indices, weight_indptr).T`
@@ -52,13 +52,104 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
5252
"""
5353
assert len(weight_data.shape) in (1, 3)
5454
if len(weight_data.shape) == 1:
55-
func = _sparse_dense_csrmm
55+
func = _sparse_dense_csrmm_v2
5656
if len(weight_data.shape) == 3:
57-
func = _sparse_dense_bsrmm
57+
func = _sparse_dense_bsrmm_v2
5858
return func(data, weight_data, weight_indices, weight_indptr)
5959

6060

61-
def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
61+
def sparse_dense_v1(data_data, data_indices, data_indptr, weight):
    """
    Multiply a sparse matrix by the transpose of a dense matrix, i.e.
    `(data_data, data_indices, data_indptr)` times `weight.T`.

    The sparse format is selected from the rank of `data_data`:
    rank 1 dispatches to the CSR kernel, rank 3 to the BSR kernel.

    Parameters
    ----------
    data_data : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)

    data_indices : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)

    data_indptr : tvm.te.Tensor
        1-D with shape [M + 1] (CSR) or
        1-D with shape [M // bs_r + 1] (BSR)

    weight : tvm.te.Tensor
        2-D with shape [N, K], float32

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [M, N]
    """
    rank = len(data_data.shape)
    # Only flat (CSR) or blocked (BSR) value arrays are understood.
    assert rank in (1, 3)
    kernel = _sparse_dense_csrmm_v1 if rank == 1 else _sparse_dense_bsrmm_v1
    return kernel(data_data, data_indices, data_indptr, weight)
94+
95+
96+
# pylint: disable=no-else-return,inconsistent-return-statements
97+
def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_lhs=False):
    """
    Sparse-dense matrix multiplication with the sparse operand on either side.

    If `sparse_lhs=False`, computes
    `dense_data` times `(sparse_data, sparse_indices, sparse_indptr).T`.

    If `sparse_lhs=True`, computes
    `(sparse_data, sparse_indices, sparse_indptr)` times `dense_data.T`.

    Parameters
    ----------
    dense_data : tvm.te.Tensor
        2-D, float32. Shape [M, K] if `sparse_lhs=False`,
        shape [N, K] if `sparse_lhs=True`.

    sparse_data : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)

    sparse_indices : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)

    sparse_indptr : tvm.te.Tensor
        1-D with shape [R + 1] (CSR) or
        1-D with shape [R // bs_r + 1] (BSR), where R is the number of rows
        of the sparse matrix: N if `sparse_lhs=False`, M if `sparse_lhs=True`.

    sparse_lhs : bool, optional
        Indicates whether lhs or rhs matrix is sparse. Default value is False.

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [M, N]
    """
    if sparse_lhs:
        # Sparse operand on the left: the v1 kernels take the sparse triple first.
        return sparse_dense_v1(sparse_data, sparse_indices, sparse_indptr, dense_data)
    # Sparse operand on the right: the v2 kernels take the dense matrix first.
    return sparse_dense_v2(dense_data, sparse_data, sparse_indices, sparse_indptr)
134+
135+
136+
def _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight):
    """CSR kernel for a sparse matrix times the transpose of a dense `weight`.

    The result has one row per CSR row (`len(data_indptr) - 1`) and one
    column per row of `weight`. Element [row, i] sums, over the stored
    entries of CSR row `row`, the stored value times `weight[i, column]`,
    where `column` is read from `data_indices`.
    """
    num_rows = get_const_tuple(data_indptr.shape)[0] - 1
    num_cols = get_const_tuple(weight.shape)[0]

    # NOTE(review): the parameter names `row` and `i` are kept from the
    # original in case te.compute derives axis names from them -- confirm.
    def _row_times_weight(row, i):
        begin = data_indptr[row]
        stored_in_row = data_indptr[row + 1] - begin
        elem_idx = te.reduce_axis((0, stored_in_row), name="elem_idx")
        pos = begin + elem_idx
        return te.sum(data_data[pos] * weight[i, data_indices[pos]], axis=elem_idx)

    return te.compute((num_rows, num_cols), _row_times_weight, tag="sparse_dense_csrmm_v1")
150+
151+
152+
def _sparse_dense_csrmm_v2(data, weight_data, weight_indices, weight_indptr):
62153
oshape = (get_const_tuple(data.shape)[0], get_const_tuple(weight_indptr.shape)[0] - 1)
63154

64155
def f(i, row):
@@ -71,10 +162,41 @@ def f(i, row):
71162
weight_val = data[i, weight_indices[elem]]
72163
return te.sum(a_val * weight_val, axis=elem_idx)
73164

74-
return te.compute(oshape, f, tag="sparse_dense_csrmm")
165+
return te.compute(oshape, f, tag="sparse_dense_csrmm_v2")
75166

76167

77-
def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
168+
def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight):
    """BSR kernel for a sparse matrix times the transpose of a dense `weight`.

    A first stage computes a 3-D tensor indexed by
    (block row, row inside the block, row of `weight`). A second stage
    flattens the first two axes, so the result has
    `(len(data_indptr) - 1) * bs_r` rows and `weight.shape[0]` columns.
    """
    dense_rows, _ = get_const_tuple(weight.shape)
    _, bs_r, bs_c = get_const_tuple(data_data.shape)
    (indptr_len,) = get_const_tuple(data_indptr.shape)
    num_block_rows = indptr_len - 1

    # NOTE(review): the parameter names of both compute callbacks are kept
    # from the original in case te.compute derives axis names from them -- confirm.
    def _compute_block(nb_j, j, i):
        first_block = data_indptr[nb_j]
        blocks_in_row = data_indptr[nb_j + 1] - first_block
        elem_idx = te.reduce_axis((0, blocks_in_row), name="elem_idx")
        block_offset = first_block + elem_idx
        c = te.reduce_axis((0, bs_c), name="c")
        block_col = data_indices[block_offset]
        sparse_val = data_data[block_offset][j][c]
        dense_val = weight[i, bs_c * block_col + c]
        return te.sum(sparse_val * dense_val, axis=[elem_idx, c])

    per_block = te.compute(
        (num_block_rows, bs_r, dense_rows), _compute_block, tag="sparse_dense_bsrmm_block_v1"
    )

    def _flatten_block_rows(m, n):
        # Split the flat output row into (block row, row inside the block).
        return per_block[tvm.tir.indexdiv(m, bs_r), tvm.tir.indexmod(m, bs_r), n]

    return te.compute(
        (num_block_rows * bs_r, dense_rows),
        _flatten_block_rows,
        tag="sparse_dense_bsrmm_v1",
    )
197+
198+
199+
def _sparse_dense_bsrmm_v2(data, weight_data, weight_indices, weight_indptr):
78200
(m, _) = get_const_tuple(data.shape)
79201
(_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
80202
(num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
@@ -95,11 +217,13 @@ def _compute_block(i, nb_j, j):
95217
idxd = tvm.tir.indexdiv
96218
idxm = tvm.tir.indexmod
97219

98-
bsrmm_block = te.compute((m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block")
220+
bsrmm_block = te.compute(
221+
(m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block_v2"
222+
)
99223
return te.compute(
100224
(m, num_blocks * bs_r),
101225
lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
102-
tag="sparse_dense_bsrmm",
226+
tag="sparse_dense_bsrmm_v2",
103227
)
104228

105229

0 commit comments

Comments
 (0)