Skip to content

Commit 575617f

Browse files
ANSHUMAN TRIPATHYANSHUMAN TRIPATHY
authored andcommitted
[Relay][Frontend] SparseTensorDenseMatMul support for Tensorflow
1 parent 0cdd285 commit 575617f

File tree

4 files changed

+94
-6
lines changed

4 files changed

+94
-6
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,40 @@ def _impl(inputs, attr, params, mod):
889889

890890
return _impl
891891

892+
def _sparse_tensor_dense_matmul():
893+
# Sparse utility from Numpy
894+
from scipy import sparse
895+
896+
def _impl(inputs, attr, params, mod):
897+
assert len(inputs) == 4, "There should be 4 input tensors"
898+
899+
indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
900+
values_tensor = _infer_value(inputs[1], params, mod).asnumpy()
901+
dense_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()
902+
903+
data = inputs[3]
904+
905+
rows = [x[0] for x in indices_tensor]
906+
cols = [x[1] for x in indices_tensor]
907+
908+
# Create Numpy sparse Tensor(CSR)
909+
weight_sp = sparse.csr_matrix((values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()))
910+
weight_sp = sparse.csr_matrix(weight_sp.transpose())
911+
912+
weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
913+
weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
914+
weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)
915+
916+
ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs])
917+
918+
# If both are true means First input was dense and second was sparse
919+
# TODO: Support other adjoint option too
920+
if attr.get("adjoint_a") and attr.get("adjoint_b"):
921+
ret = _op.transpose(ret)
922+
923+
return ret
924+
925+
return _impl
892926

893927
def _identity():
894928
def _impl(inputs, attr, params, mod):
@@ -2357,6 +2391,7 @@ def _impl(inputs, attr, params, mod):
23572391
"Softplus": _softplus(),
23582392
"SpaceToBatchND": _space_to_batch_nd(),
23592393
"SpaceToDepth": _space_to_depth(),
2394+
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
23602395
"Split": _split(False),
23612396
"SplitV": _split(True),
23622397
"Sqrt": AttrCvt("sqrt"),

python/tvm/relay/op/nn/nn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,7 +2046,7 @@ def sparse_transpose(x):
20462046
20472047
Parameters
20482048
----------
2049-
x : namedtuple.
2049+
x : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
20502050
The sparse weight matrix for the fast matrix transpose.
20512051
20522052
Returns
@@ -2055,7 +2055,9 @@ def sparse_transpose(x):
20552055
Tuple of output sparse tensor (same shape and format as input),
20562056
i.e. if CSR then output is in ([data, indices, indptr]) form
20572057
"""
2058-
return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
2058+
if hasattr(x, "indices"):
2059+
return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
2060+
return expr.TupleWrapper(_make.sparse_transpose(x[0], x[1], x[2]), 3)
20592061

20602062

20612063
def contrib_conv2d_winograd_without_weight_transform(

python/tvm/topi/cuda/sparse.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
180180
assert (
181181
mb >= mi
182182
), "Number of block rows in dense matrix must be larger than warp size: {} vs {}.".format(
183-
warp_size, m
183+
warp_size, mb
184184
)
185185
mo = ceil_div(mb, mi)
186186
ni = 1 # TODO(tkonolige): how do I compute the number of warps per block?
@@ -367,9 +367,14 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
367367
and isinstance(inputs[2], relay.Constant)
368368
and isinstance(inputs[3], relay.Constant)
369369
):
370-
sparse_matrix = sp.bsr_matrix(
371-
(inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
372-
)
370+
if len(inputs[1].data.asnumpy().shape) == 1:
371+
sparse_matrix = sp.csr_matrix(
372+
(inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
373+
).tobsr()
374+
else :
375+
sparse_matrix = sp.bsr_matrix(
376+
(inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
377+
)
373378
warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
374379
sparse_matrix = pad_sparse_matrix(sparse_matrix, warp_size)
375380
return relay.nn._make.sparse_dense_padded(

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,52 @@ def test_forward_batch_matmul():
17491749
_test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False)
17501750
_test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
17511751

1752+
#######################################################################
1753+
# SparseTensorDenseMatMul
1754+
# ----------------------------------
1755+
1756+
1757+
def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False):
1758+
""" One iteration of sparse_dense_matmul """
1759+
1760+
#TODO: Support adjoint options too
1761+
for adjoint_a in [False]:
1762+
for adjoint_b in [False]:
1763+
with tf.Graph().as_default():
1764+
A_sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]], values=[4., 8.], dense_shape=A_shape)
1765+
B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
1766+
1767+
if flip:
1768+
result = tf.sparse.sparse_dense_matmul(B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
1769+
else:
1770+
result = tf.sparse.sparse_dense_matmul(A_sp, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
1771+
1772+
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
1773+
1774+
#TODO: There is an issue in cuda scheduling for csr, work in progress
1775+
compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)
1776+
1777+
def test_forward_sparse_dense_matmul():
1778+
""" sparse_dense_matmul op test"""
1779+
###################################################################
1780+
#
1781+
# In order to create a SparseTensor, it requires 3 input as below:
1782+
# SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
1783+
#
1784+
# Above Sparse can be represented in Dense as below:
1785+
# [[1, 0, 0, 0]
1786+
# [0, 0, 2, 0]
1787+
# [0, 0, 0, 0]]
1788+
#
1789+
#------------------------------------------------------------------
1790+
1791+
#TODO: False case for flip need to be supported
1792+
#_test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 4], [4, 3], "float32")
1793+
_test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 5], [4, 3], "float32", True)
1794+
_test_sparse_dense_matmul([[0, 0], [1, 2]], [4., 8.], [3, 3], [3, 3], "float32", True)
1795+
_test_sparse_dense_matmul([[0, 0], [1, 3], [4, 3]], [3., 6., 9.], [5, 5], [5, 5], "float32", True)
1796+
_test_sparse_dense_matmul([[0, 0], [1, 3], [4, 3]], [3., 6., 9.], [9, 5], [7, 9], "float32", True)
1797+
17521798

17531799
#######################################################################
17541800
# StridedSlice

0 commit comments

Comments
 (0)