
S2d #2


Merged: 93 commits, Apr 28, 2020

Commits
1077352
[Relay][Frontend][TFLite] Add parser support for shape and range
dhruvaray Apr 14, 2020
28fcb2d
[RELAY][PYTORCH]isNan, isinf, isfinite, ceil, clamp, round ops (#5316)
siju-samuel Apr 14, 2020
16d3da1
[TIR] Refactor MakePackedAPI to target dependent stage. (#5326)
tqchen Apr 14, 2020
2b6e845
[RELAY] Remove re-exports of tvm.transform (#5337)
tqchen Apr 15, 2020
747a4a8
[LLVM] Use llvm::FunctionCallee in IRBuilder::CreateCall with LLVM 11…
Apr 15, 2020
d7c977c
[CI] Fix build.sh to propagate --network=host to the docker build com…
leandron Apr 15, 2020
ab5afbc
[Runtime][Relay][Cleanup] Clean up for memory pass to enable heteroge…
jroesch Apr 15, 2020
c8e933e
Windows Support for cpp_rpc (#4857)
jmorrill Apr 15, 2020
19ce0a9
[PYTORCH]Take, Topk op support (#5332)
siju-samuel Apr 15, 2020
8a42257
[TOPI] Using x86 schedules for ARM conv2d. (#5334)
anijain2305 Apr 15, 2020
23a5e8e
[TOPI] Improve get_valid_count and nms performance for CUDA (#5339)
Laurawly Apr 15, 2020
e1a1f55
[PYTHON] Enhance with_attr API, cleanup MakeAPILegacy in testcases (#…
tqchen Apr 15, 2020
d3d155c
[TIR] Remove ProducerConsumer and AllocateNode::new_expr (#5333)
tqchen Apr 15, 2020
ecdb00c
[BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)
Apr 15, 2020
8639224
[Tutorial, QNN] Add tutorial for loading quantized PyTorch model (#5321)
masahi Apr 15, 2020
e9ae136
[DOCS] Bring relay docs to the top-level flat view (#5343)
tqchen Apr 15, 2020
17b4961
[TOPI][PYTORCH]Logical & Bitwise operator support (#5341)
siju-samuel Apr 16, 2020
7448081
[RELAY][BYOC] Register pattern tables from external codegens (#5262)
mbaret Apr 16, 2020
b9aa07f
[RUNTIME][CRT] support DLTensor whose ndim == 0 (#5344)
windclarion Apr 16, 2020
38819e5
[BYOC][FIX] Fix typo in "default" (#5348)
mbaret Apr 16, 2020
b01fb67
enable tsim and fsim for GPU build (#5352)
zhiics Apr 16, 2020
5e3222f
[CRT]Compilation warnings fixed for 32bit and 64bit compilation (#5349)
siju-samuel Apr 16, 2020
23e3e9e
[PYTORCH]Tensor creation ops support (#5347)
siju-samuel Apr 17, 2020
64db78d
[Hexagon] Add hexagon_posix.cc to TVM/RT sources in the right place (…
Apr 17, 2020
9568e0b
[TOPI-ARM] Do not alter layout if layout is NHWC (#5350)
anijain2305 Apr 17, 2020
dbfd277
[TIR] Make lower_warp_memory support extent(threadIdx.x) < warp_size …
roastduck Apr 17, 2020
453da00
[RELAY][PYTORCH]GroupNorm op support added (#5358)
siju-samuel Apr 17, 2020
7f995ce
docker: Drop caffe2 download progress bars (#5359)
mshawcroft Apr 17, 2020
29da9ec
fix fuse over functions that are handled by external codegen (#5365)
zhiics Apr 18, 2020
be0c661
[RUNTIME] FastRPC interface for Hexagon runtime (#5353)
Apr 18, 2020
9707ae5
[TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Unified …
tqchen Apr 18, 2020
458814b
[TIR] Fix lower_warp_memory when there are >1 warp buffers (#5368)
roastduck Apr 19, 2020
eae387b
Add cuda target check to dense tensorcore schedule. (#5376)
jwfromm Apr 19, 2020
9248e15
Remove developer facing api from frontend exports. (#5375)
shoubhik Apr 19, 2020
f635fd5
[TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes. (#5372)
tqchen Apr 19, 2020
b5925fe
[PYTORCH]Unary Ops (#5378)
siju-samuel Apr 20, 2020
f7ca70d
[TIR][REFACTOR] RewriteForTensorCore -> te/schedule (#5379)
tqchen Apr 20, 2020
a52ab12
[Blocksparse] Pipeline for lowering dense model to sparse-dense (#5377)
antinucleon Apr 20, 2020
118f943
[REFACTOR][TE] Inline -> te/schedule/operation_inline.h (#5386)
tqchen Apr 21, 2020
8499d01
[ARITH] Remove the legacy Simplify, migrate to Analyzer. (#5385)
tqchen Apr 21, 2020
1f7b94a
[ARITH] Remove legacy const pattern functions (#5387)
tqchen Apr 21, 2020
0969181
Add ability to have multiple copies of same input to onnx_inputs. (#5…
jwfromm Apr 21, 2020
7d52c1a
[Topi, ARM] Disable Winograd for quantized tensors. (#5363)
anijain2305 Apr 21, 2020
c4bebb8
Fix test_ir_type. (#5390)
Apr 21, 2020
59c867d
Tf2 test fixups (#5391)
Apr 21, 2020
babdf7e
[PYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
tqchen Apr 21, 2020
f448dac
[LLVM] Use ArrayRef<int> in calls to CreateShuffleVector (#5399)
Apr 22, 2020
a407dd5
[KERAS]Minimum & AlphaDropout op support (#5380)
siju-samuel Apr 22, 2020
d7ec8e0
Factor out import of common tflite.Operator in tflite frontend. (#5355)
Apr 22, 2020
6dda520
[Fix] Remove the duplicate PrintIR pass in Relay (#5403)
icemelon Apr 22, 2020
848a1f5
Update dmlc-core to latest (#5401)
tqchen Apr 22, 2020
b9aa740
[TIR] Enhance Substitute, python bindings for Substitute/PostOrderVis…
tqchen Apr 22, 2020
6e8e8b1
[Relay] Fix memory leak when accessing NDArray (#5413)
icemelon Apr 22, 2020
697327c
Customize SI prefix in logging (#5411)
Apr 22, 2020
4b03597
[LLVM] Replace calls to Type::getVectorNumElements (#5398)
Apr 22, 2020
b87f73c
Don't remove() TempDirectory in __del__ after atexit hook runs. (#5414)
Apr 23, 2020
86fcfe0
[TIR][REFACTOR] Remove ir_pass in favor of analysis/transform. (#5415)
tqchen Apr 23, 2020
7893d20
[RUNTIME][CONTRIB] CoreML Runtime (#5283)
kazum Apr 23, 2020
05bfd1c
[DOCS] Migrate HLS documents from md to rst (#5419)
kazum Apr 23, 2020
e68bb67
fix [RUNTIME][VULKAN] vkBuffer released before memory copy command se…
samwyi Apr 23, 2020
248300b
[Frontend] Asymmetric padding of convolution support (#4803)
FrozenGene Apr 23, 2020
db9e412
[cuDNN] Add cuDNN grouped convolutions support (#5319)
wpan11nv Apr 23, 2020
efdd844
[CI] Migrate Tensorflow and Tensorflow lite in CI to 2.1.0 (#5392)
Apr 23, 2020
d5560b4
[DOCS] Migrate some markdowns to rst, fix sphinx3 warnings (#5416)
tqchen Apr 23, 2020
af079c1
[BYOC] Use Non-Recursive Visitor/Mutator (#5410)
comaniac Apr 23, 2020
0b0990c
[RFC] Pytest environment improvements (#5421)
Apr 23, 2020
d8f4641
[MXNET]DepthToSpace & SpaceToDepth Operator (#5408)
siju-samuel Apr 23, 2020
66c16cf
Add option to specify flatbuffers location (#5425)
michalpiszczek Apr 23, 2020
f0b5a9e
[FRONTEND][MXNET] support elemwise logic ops (#5361)
kazum Apr 24, 2020
9116de1
[PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass…
tqchen Apr 24, 2020
baf6674
[PYTORCH]where, addcdiv, addcmul op support (#5383)
siju-samuel Apr 24, 2020
155601b
[FRONTEND][TFLITE]Gather, StridedSlice op support added (#4788)
siju-samuel Apr 24, 2020
3b5c577
misc fixes for ROCm (pointer lifetime, runtime::String refactor) (#5431)
t-vi Apr 24, 2020
89a6237
Corrected TVM autotuning on GPU (#5432)
JishinMaster Apr 24, 2020
6361483
[RUNTIME][OBJECT] Introduce static slots for common objects. (#5423)
tqchen Apr 24, 2020
ca93121
[RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support (#5395)
siju-samuel Apr 25, 2020
1bebb6e
[PYTORCH]Rsub, Embedded, OneHot ops support (#5434)
siju-samuel Apr 25, 2020
1d331f4
fix miopen pad (#5433)
t-vi Apr 25, 2020
2c9da4d
Add TopK to ONNX Frontend (#5441)
mbrookhart Apr 25, 2020
ae9a581
[CodeGen] Cleanup generated code (#5424)
wpan11nv Apr 25, 2020
4c06e2e
[RELAY] Move frontend utils (#5345)
mbaret Apr 25, 2020
d186475
[KERAS]Embedding layer (#5444)
siju-samuel Apr 26, 2020
df52be0
[Docs] VTA install doc migration from md to rst (#5442)
tmoreau89 Apr 26, 2020
fa42562
Improve IntervalSet's floormod (#5367)
yongfeng-nv Apr 27, 2020
339c8ff
[ONNX]GatherNd, Round, IsNaN, IsInf (#5445)
siju-samuel Apr 27, 2020
4ecb171
[relay][topi] Add operation relay.nn.dilate() which calls topi.nn.dil…
notoraptor Apr 27, 2020
7717425
[Pytorch] fix translation of transpose when axis argument is as a lis…
n-nez Apr 27, 2020
2458cc3
[TOPI,RELAY][TFLITE] Sparse to dense operator
dhruvaray Apr 25, 2020
69e9bd7
use param name in documentation
dhruvaray Apr 27, 2020
aa0f81b
sphinx doc errors fixed
dhruvaray Apr 27, 2020
db3d04c
Merge remote-tracking branch 'upstream/master' into s2d
dhruvaray Apr 28, 2020
4e42714
incorporated code review comments
dhruvaray Apr 28, 2020
aaf46d4
Fixed indentation
dhruvaray Apr 28, 2020
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -50,6 +50,7 @@ List of operators
   topi.expand_dims
   topi.reshape
   topi.unravel_index
   topi.sparse_to_dense
   topi.squeeze
   topi.concatenate
   topi.split
@@ -154,6 +155,7 @@ topi
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
.. autofunction:: topi.unravel_index
.. autofunction:: topi.sparse_to_dense
.. autofunction:: topi.squeeze
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
1 change: 1 addition & 0 deletions docs/langref/relay_op.rst
@@ -130,6 +130,7 @@ This level enables additional math and transform operators.
   tvm.relay.tile
   tvm.relay.reverse
   tvm.relay.unravel_index
   tvm.relay.sparse_to_dense


**Level 4: Broadcast and Reductions**
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -288,6 +288,16 @@ struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
  }
};  // struct SequenceMaskAttrs.

/*! \brief Attributes used in sparse_to_dense operator */
struct SparseToDenseAttrs : public tvm::AttrsNode<SparseToDenseAttrs> {
  Array<Integer> output_shape;

  TVM_DECLARE_ATTRS(SparseToDenseAttrs, "relay.attrs.SparseToDenseAttrs") {
    TVM_ATTR_FIELD(output_shape).describe("Shape of the dense output tensor");
  }
};  // struct SparseToDenseAttrs

/*! \brief Attributes for ndarray_size operator */
struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
  DataType dtype;
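Note the design choice in SparseToDenseAttrs above: output_shape is a compile-time attribute (an Array<Integer>), not a fourth runtime input, which is what lets the type relation in transform.cc further down assign a fully static output type. A minimal sketch of how the attribute surfaces on the Python side (illustrative only; names and shapes are arbitrary):

from tvm import relay

# output_shape is baked into the call's attrs, not passed as a tensor.
a = relay.var("a", shape=(3,), dtype="int32")  # sparse_indices
b = relay.var("b", shape=(3,), dtype="int32")  # sparse_values
c = relay.var("c", shape=(), dtype="int32")    # default_value
call = relay.sparse_to_dense(a, b, c, [5])
print(call.attrs.output_shape)  # -> [5]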
31 changes: 31 additions & 0 deletions python/tvm/relay/frontend/tflite.py
@@ -121,6 +121,7 @@ def __init__(self, model, subgraph, exp_tab):
            'SOFTMAX': self.convert_softmax,
            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
            'SPACE_TO_DEPTH': self.convert_space_to_depth,
            'SPARSE_TO_DENSE': self.convert_sparse_to_dense,
            'SPLIT': self.convert_split,
            'SQRT': self.convert_sqrt,
            'SQUARE': self.convert_square,
@@ -2075,6 +2076,36 @@ def convert_space_to_depth(self, op):

        return out

    def convert_sparse_to_dense(self, op):
        """Convert TFLite SPARSE_TO_DENSE"""
        try:
            from tflite.TensorType import TensorType
        except ImportError:
            raise ImportError("The tflite package must be installed")

        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 4, "input tensors length should be 4"

        for t in input_tensors:
            assert not t.qnn_params, "Quantized input is not expected."

        indices, values = input_tensors[0], input_tensors[2]
        default_value = input_tensors[3]
        output_shape = input_tensors[1]

        for t in [indices, output_shape]:
            t_type = t.tensor.Type()
            assert t_type in (TensorType.INT32, TensorType.INT64)

        out = _op.sparse_to_dense(
            self.get_expr(indices.tensor_idx),
            self.get_expr(values.tensor_idx),
            self.get_expr(default_value.tensor_idx),
            list(self.get_tensor_value(output_shape))
        )

        return out

    def convert_prelu(self, op):
        """Convert TFLite PReLU"""
        input_tensors = self.get_input_tensors(op)
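For reference, TFLite's SPARSE_TO_DENSE packs its inputs in the order (indices, output_shape, values, default_value), which is why the converter above reads input_tensors[1] as the output shape. A NumPy sketch of the scatter semantics being mapped (illustrative only, not part of the diff):

import numpy as np

def sparse_to_dense_ref(sparse_indices, sparse_values, default_value, output_shape):
    """NumPy reference for the operator's scatter semantics."""
    dense = np.full(output_shape, default_value)
    indices = np.asarray(sparse_indices)
    if indices.ndim <= 1:
        # 0-D or 1-D indices address positions in a 1-D output.
        dense[indices] = sparse_values
    else:
        # Each row of a 2-D indices tensor is one coordinate in the n-D output.
        values = np.broadcast_to(np.asarray(sparse_values), (indices.shape[0],))
        for row, val in zip(indices, values):
            dense[tuple(row)] = val
    return dense

# Matches the Relay docstring example below:
# sparse_to_dense_ref([[0, 0], [1, 1]], [3, 3], 0, [2, 2]) -> [[3, 0], [0, 3]]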
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
@@ -55,6 +55,7 @@
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
_reg.register_injective_schedule("unravel_index")
_reg.register_injective_schedule("sparse_to_dense")

# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
29 changes: 29 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -884,3 +884,32 @@ def unravel_index(indices, shape):
    """

    return _make.unravel_index(indices, shape)

def sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
    """Converts a sparse representation into a dense tensor.

    Example::

        sparse_to_dense([[0, 0], [1, 1]], [3, 3], 0, [2, 2]) = [[3, 0], [0, 3]]

    Parameters
    ----------
    sparse_indices : relay.Expr
        A 0-D, 1-D, or 2-D tensor of integers containing the locations of the sparse values.

    sparse_values : relay.Expr
        A 0-D or 1-D tensor containing the sparse values for the sparse indices.

    default_value : relay.Expr
        A 0-D tensor containing the default value for the remaining locations.

    output_shape : list of int
        Shape of the dense output tensor.

    Returns
    -------
    result : relay.Expr
        Dense tensor of shape output_shape. Has the same type as sparse_values.
    """

    return _make.sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape)
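A minimal usage sketch of the new Relay API (it mirrors the tests added later in this PR; executor and context names follow the 2020-era API):

import numpy as np
import tvm
from tvm import relay

a = relay.var("a", relay.TensorType((3,), "int32"))  # sparse_indices
b = relay.var("b", relay.TensorType((3,), "int32"))  # sparse_values
c = relay.var("c", relay.TensorType((), "int32"))    # default_value
func = relay.Function([a, b, c], relay.sparse_to_dense(a, b, c, [5]))

intrp = relay.create_executor("graph", ctx=tvm.cpu(), target="llvm")
out = intrp.evaluate(func)(
    np.array([0, 1, 4], dtype="int32"),
    np.array([3, 3, 3], dtype="int32"),
    np.array(0, dtype="int32"),
)
print(out.asnumpy())  # [3 3 0 0 3]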
81 changes: 81 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -2732,5 +2732,86 @@ Example::
    .set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

// sparse_to_dense
TVM_REGISTER_NODE_TYPE(SparseToDenseAttrs);

bool SparseToDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                      const TypeReporter& reporter) {
  CHECK_EQ(num_inputs, 3);
  auto sparse_indices = types[0].as<TensorTypeNode>();
  auto sparse_values = types[1].as<TensorTypeNode>();
  auto default_value = types[2].as<TensorTypeNode>();
  CHECK(sparse_indices != nullptr && sparse_values != nullptr && default_value != nullptr);

  CHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be a tensor of integers";

  CHECK_LE(sparse_indices->shape.size(), 3) << "sparse_indices must be a 0-D, 1-D, or 2-D tensor";

  CHECK_LE(sparse_values->shape.size(), 2) << "sparse_values must be a 0-D or 1-D tensor";

  CHECK_EQ(default_value->shape.size(), 0) << "default_value should be a scalar";

  const auto* param = attrs.as<SparseToDenseAttrs>();
  CHECK(param != nullptr);

  Array<IndexExpr> oshape;
  for (auto i : param->output_shape) {
    oshape.push_back(i);
  }
  reporter->Assign(types[3], TensorType(oshape, sparse_values->dtype));
  return true;
}

Array<te::Tensor> SparseToDenseCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                       const Type& out_type) {
  CHECK_EQ(inputs.size(), 3);
  const auto* param = attrs.as<SparseToDenseAttrs>();
  CHECK(param != nullptr);
  return {topi::sparse_to_dense(inputs[0], inputs[1], inputs[2], param->output_shape)};
}

TVM_REGISTER_GLOBAL("relay.op._make.sparse_to_dense")
    .set_body_typed([](Expr indices, Expr values, Expr default_value, Array<Integer> output_shape) {
      auto attrs = make_object<SparseToDenseAttrs>();
      attrs->output_shape = std::move(output_shape);
      static const Op& op = Op::Get("sparse_to_dense");
      return Call(op, {indices, values, default_value}, Attrs(attrs));
    });

RELAY_REGISTER_OP("sparse_to_dense")
    .describe(R"code(Converts a sparse representation into a dense tensor.

- **sparse_indices**: A 0-D, 1-D, or 2-D tensor of integers containing the locations of the sparse values.

- **sparse_values**: A 0-D or 1-D tensor containing the sparse values for the sparse indices.

- **default_value**: A 0-D tensor containing the default value for the remaining locations.

- **output_shape**: A list of integers. Shape of the dense output tensor.

Example::

  sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4]) = [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]

)code" TVM_ADD_FILELINE)
    .set_num_inputs(3)
    .set_support_level(3)
    .set_attrs_type<SparseToDenseAttrs>()
    .add_argument("sparse_indices", "Tensor", "Contains sparse indices.")
    .add_argument("sparse_values", "Tensor", "Contains values for sparse indices.")
    .add_argument("default_value", "Tensor", "Value to set for non-sparse indices.")
    .add_type_rel("SparseToDense", SparseToDenseRel)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
    .set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);

} // namespace relay
} // namespace tvm
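The type relation above derives the result type purely from the output_shape attribute plus the dtype of sparse_values. A quick check of that behaviour (a sketch in the spirit of the PR's tests, with the InferType helper written out explicitly):

import tvm
from tvm import relay

def run_infer_type(expr):
    # Equivalent of the helper the Relay tests below rely on.
    mod = tvm.IRModule.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    return mod["main"].body

a = relay.var("a", relay.TensorType((2, 2), "int32"))  # sparse_indices
b = relay.var("b", relay.TensorType((2,), "float32"))  # sparse_values
c = relay.var("c", relay.TensorType((), "float32"))    # default_value
zz = run_infer_type(relay.sparse_to_dense(a, b, c, [3, 4]))
assert zz.checked_type == relay.TensorType((3, 4), "float32")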
8 changes: 4 additions & 4 deletions tests/python/frontend/onnx/test_forward.py
@@ -2409,7 +2409,7 @@ def verify_topk(input_dims, K, axis=-1):
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
helper.make_tensor_value_info("K", TensorProto.INT64, [1,])],
initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])],
outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims),
helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)])

model = helper.make_model(graph, producer_name='topk_test')
@@ -2418,10 +2418,10 @@ def verify_topk(input_dims, K, axis=-1):
onnx_out = get_onnxruntime_output(model, [indata, k])

for target, ctx in [('llvm', tvm.cpu())]:
tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims],
output_dtype=['float32', 'int64'])
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)

for n in [12, 32]:
for shape in [[n], [n, n], [n, n, n]]:
for k in [1, 5, 10]:
@@ -2430,7 +2430,7 @@ def verify_topk(input_dims, K, axis=-1):
verify_topk([n, n, n], 5, 0)
verify_topk([n, n, n], 5, 1)
verify_topk([n, n, n], 5, 2)


def test_roi_align():
def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0):
65 changes: 64 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
@@ -723,7 +723,6 @@ def test_all_resize():
    if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
        _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)

#######################################################################
# Concatenation
# -------------
@@ -1583,6 +1582,69 @@ def test_forward_spacetodepth():
    _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
    _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)

#######################################################################
# Sparse To Dense
# ---------------
def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
    # tflite 1.13 convert method does not accept empty shapes
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        with tf.Graph().as_default():
            indices = tf.placeholder(shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices")
            values = tf.placeholder(shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values")
            oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype))
            dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value")

            output = tf.sparse_to_dense(indices, oshape, values, dv)

            compare_tflite_with_tvm(
                [sparse_indices, sparse_values, default_value],
                ["indices", "values", "default_value"],
                [indices, values, dv],
                [output]
            )

def test_forward_sparse_to_dense():
    '''
    Works in tvm/topi/tensorflow. But tflite converter breaks this test case
    _test_sparse_to_dense(
        np.int32(1),
        np.int32(3),
        np.int32(0),
        np.array([5]).astype("int32")
    )
    '''
    # vector
    _test_sparse_to_dense(
        np.array([0, 1, 4]).astype("int32"),
        np.array([3, 3, 3]).astype("int32"),
        np.int32(0),
        np.array([5]).astype("int32")
    )

    # vector nXd
    _test_sparse_to_dense(
        np.array([[0, 0], [1, 2]]).astype("int32"),
        np.array([1, 2]).astype("int32"),
        np.int32(0),
        np.array([3, 4]).astype("int32")
    )

    _test_sparse_to_dense(
        np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"),
        np.array([1, 2]).astype("int32"),
        np.int32(4),
        np.array([2, 3, 4]).astype("int32")
    )

    # floats
    _test_sparse_to_dense(
        np.array([0, 1, 4]).astype("int32"),
        np.array([3.1, 3.1, 3.1]).astype("float32"),
        np.float32(3.5),
        np.array([5]).astype("int32")
    )

#######################################################################
# Fully Connected
# ---------------
@@ -1932,6 +1994,7 @@ def test_forward_mediapipe_hand_landmark():
    test_forward_stridedslice()
    test_forward_depthtospace()
    test_forward_spacetodepth()
    test_forward_sparse_to_dense()

    # NN
    test_forward_convolution()
49 changes: 48 additions & 1 deletion tests/python/relay/test_op_level3.py
@@ -747,6 +747,52 @@ def verify_unravel_index(indices, shape, dtype):
    # output which is inline with Tensorflow
    # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)

def test_sparse_to_dense():
    def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, expected):
        sparse_indices_data = np.array(sparse_indices)
        sparse_values_data = np.array(sparse_values)
        default_value_data = np.array(default_value)

        a = relay.var("a", relay.TensorType(sparse_indices_data.shape, str(sparse_indices_data.dtype)))
        b = relay.var("b", relay.TensorType(sparse_values_data.shape, str(sparse_values_data.dtype)))
        c = relay.var("c", relay.TensorType(default_value_data.shape, str(default_value_data.dtype)))
        d = relay.sparse_to_dense(a, b, c, output_shape)

        zz = run_infer_type(d)
        assert zz.checked_type == relay.ty.TensorType(output_shape, str(sparse_values_data.dtype))

        func = relay.Function([a, b, c], d)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(
                    sparse_indices_data, sparse_values_data, default_value_data
                )
                tvm.testing.assert_allclose(op_res.asnumpy(), expected, rtol=1e-5)

    verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0])  # scalar
    verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3])  # vector
    verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4],
                           [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])  # nXd

    verify_sparse_to_dense(
        [[0, 0, 0], [1, 2, 3]],
        [1, 2],
        4,
        [2, 3, 4],
        [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]]
    )  # nXd

    verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])  # floats

    # negative test cases (the type relation should reject these)
    # sparse_indices should be ints
    # verify_sparse_to_dense([[0.1, 1.1, 4.1], [0, 2, 4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
    # sparse_values should be 0-D or 1-D only
    # verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
    # sparse_indices should not be a >2-D tensor
    # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])

if __name__ == "__main__":
    test_arange()
@@ -780,4 +826,5 @@ def verify_unravel_index(indices, shape, dtype):
    test_gather_nd()
    test_isfinite()
    test_isinf()
    test_unravel_index()
    test_sparse_to_dense()