Skip to content

Commit a4ed84f

Browse files
ANSHUMAN TRIPATHYANSHUMAN TRIPATHY
authored andcommitted
[8] Review comments handled
1 parent 4623979 commit a4ed84f

File tree

4 files changed

+65
-13
lines changed

4 files changed

+65
-13
lines changed

python/tvm/relay/op/nn/nn.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,17 +2000,17 @@ def sparse_dense(dense_mat, sparse_mat, sparse_lhs=False):
20002000
a dense matrix and `sparse_mat` is a sparse (either BSR or CSR) namedtuple with
20012001
fields `data`, `indices`, and `indptr`.
20022002
2003-
.. math::
2004-
2005-
if sparse_lhs=True
2003+
\if sparse_lhs=False:
2004+
.. math::
20062005
2007-
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
2008-
= \mbox{matmul}(D, \mbox{as_dense}(S)^T)[m, n]
2006+
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
2007+
= \mbox{matmul}(D, \mbox{as_dense}(S)^T)[m, n]
20092008
2010-
if sparse_lhs=False
2009+
\if sparse_lhs=True:
2010+
.. math::
20112011
2012-
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
2013-
= \mbox{matmul}(\mbox{as_dense}(S), (D)^T)[m, n]
2012+
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
2013+
= \mbox{matmul}(\mbox{as_dense}(S), (D)^T)[m, n]
20142014
20152015
where `as_dense` returns dense equivalent of the given S(sparse matrix)
20162016
while performing matmul with given D(dense matrix).

python/tvm/topi/cuda/sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
363363
sparse_dense implementation for one that operates on a padded matrix. We
364364
also padd the matrix.
365365
"""
366-
# TODO(ANSHUMAN87): Handle for sparse_data case too
366+
# TODO(ANSHUMAN87): Handle for sparse_lhs case too
367367
if (
368368
isinstance(inputs[1], relay.Constant)
369369
and isinstance(inputs[2], relay.Constant)

python/tvm/topi/nn/sparse.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def f(i, row):
166166

167167

168168
def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight):
169-
(k, _) = get_const_tuple(weight.shape)
169+
(m, _) = get_const_tuple(weight.shape)
170170
(_, bs_r, bs_c) = get_const_tuple(data_data.shape)
171171
(num_blocks_plus_1,) = get_const_tuple(data_indptr.shape)
172172
num_blocks = num_blocks_plus_1 - 1
@@ -187,11 +187,11 @@ def _compute_block(nb_j, j, i):
187187
idxm = tvm.tir.indexmod
188188

189189
bsrmm_block = te.compute(
190-
(num_blocks, bs_r, k), _compute_block, tag="sparse_dense_bsrmm_block_v1"
190+
(num_blocks, bs_r, m), _compute_block, tag="sparse_dense_bsrmm_block_v1"
191191
)
192192
return te.compute(
193-
(num_blocks * bs_r, k),
194-
lambda m, n: bsrmm_block[idxd(n, bs_r), idxm(n, bs_r), m],
193+
(num_blocks * bs_r, m),
194+
lambda m, n: bsrmm_block[idxd(m, bs_r), idxm(m, bs_r), n],
195195
tag="sparse_dense_bsrmm_v1",
196196
)
197197

tests/python/topi/python/test_topi_sparse.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,31 @@ def test_sparse_dense_csr():
272272
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
273273

274274

275+
def test_sparse_dense_csr_reverse():
276+
M, N, K, density = 1, 17, 47, 0.2
277+
X_np = np.random.randn(M, K).astype("float32")
278+
W_sp_np = sp.random(N, K, density=density, format="csr", dtype="float32")
279+
W_np = W_sp_np.todense()
280+
Y_np = W_np.dot(X_np.T)
281+
282+
W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
283+
W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
284+
W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
285+
X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
286+
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr, sparse_lhs=True)
287+
s = te.create_schedule(Y.op)
288+
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
289+
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
290+
func(
291+
tvm.nd.array(X_np),
292+
tvm.nd.array(W_sp_np.data),
293+
tvm.nd.array(W_sp_np.indices),
294+
tvm.nd.array(W_sp_np.indptr),
295+
Y_tvm,
296+
)
297+
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
298+
299+
275300
def test_sparse_transpose_csr():
276301
N, density = 1023, 0.3
277302

@@ -368,6 +393,31 @@ def test_sparse_dense_bsr_relu(ctx, target):
368393
verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, False, ctx, target)
369394

370395

396+
def test_sparse_dense_bsr_reverse():
397+
M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
398+
X_np = np.random.randn(M, K).astype("float32")
399+
W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
400+
W_np = W_sp_np.todense()
401+
Y_np = W_np.dot(X_np.T)
402+
403+
W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
404+
W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
405+
W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
406+
X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
407+
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr, sparse_lhs=True)
408+
s = te.create_schedule(Y.op)
409+
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
410+
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
411+
func(
412+
tvm.nd.array(X_np),
413+
tvm.nd.array(W_sp_np.data),
414+
tvm.nd.array(W_sp_np.indices),
415+
tvm.nd.array(W_sp_np.indptr),
416+
Y_tvm,
417+
)
418+
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
419+
420+
371421
@tvm.testing.uses_gpu
372422
def test_sparse_dense_bsr_randomized():
373423
for _ in range(20):
@@ -480,3 +530,5 @@ def test_sparse_dense_padded_alter_op():
480530
test_sparse_transpose_csr()
481531
test_sparse_dense_padded_cuda()
482532
test_sparse_dense_padded_alter_op()
533+
test_sparse_dense_csr_reverse()
534+
test_sparse_dense_bsr_reverse()

0 commit comments

Comments
 (0)