Skip to content

Commit d5d844c

Browse files
ANSHUMAN TRIPATHYANSHUMAN TRIPATHY
authored andcommitted
[1] Review comments handled
1 parent eda4d6e commit d5d844c

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -904,8 +904,8 @@ def _impl(inputs, attr, params, mod):
904904

905905

906906
def _sparse_tensor_dense_matmul():
907-
# Sparse utility from Numpy
908-
from scipy import sparse
907+
# Sparse utility from scipy
908+
from scipy.sparse import csr_matrix
909909

910910
def _impl(inputs, attr, params, mod):
911911
assert len(inputs) == 4, "There should be 4 input tensors"
@@ -919,11 +919,11 @@ def _impl(inputs, attr, params, mod):
919919
rows = [x[0] for x in indices_tensor]
920920
cols = [x[1] for x in indices_tensor]
921921

922-
# Create Numpy sparse Tensor(CSR)
923-
weight_sp = sparse.csr_matrix(
922+
# Create scipy sparse Tensor(CSR)
923+
weight_sp = csr_matrix(
924924
(values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
925925
)
926-
weight_sp = sparse.csr_matrix(weight_sp.transpose())
926+
weight_sp = csr_matrix(weight_sp.transpose())
927927

928928
weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
929929
weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
@@ -935,6 +935,8 @@ def _impl(inputs, attr, params, mod):
935935
# TODO: Support other adjoint option too
936936
if attr.get("adjoint_a") and attr.get("adjoint_b"):
937937
ret = _op.transpose(ret)
938+
else:
939+
raise tvm.error.OpAttributeUnImplemented("Adjoint option is not supported yet.")
938940

939941
return ret
940942

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,9 +1762,7 @@ def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=Fal
17621762
for adjoint_a in [False]:
17631763
for adjoint_b in [False]:
17641764
with tf.Graph().as_default():
1765-
A_sp = tf.sparse.SparseTensor(
1766-
indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=A_shape
1767-
)
1765+
A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape)
17681766
B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
17691767

17701768
if flip:

0 commit comments

Comments
 (0)