
Commit 3da94a0

dhruvaray, siju-samuel, tqchen, Krzysztof Parzyszek, and leandron authored
S2d (#2)
* [Relay][Frontend][TFLite] Add parser support for shape and range Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
* [RELAY][PYTORCH] isNan, isinf, isfinite, ceil, clamp, round ops (apache#5316)
* [RELAY][PYTORCH] isNan, isinf, isfinite, ceil, clamp, round ops
* Review comments
* [TIR] Refactor MakePackedAPI to a target-dependent stage. (apache#5326) Previously MakePackedAPI was in the target-independent stage, but it nevertheless requires the device_type information that will be bound at a later target-dependent stage. The previous implementation was due to the limitation of LoweredFunc, which cannot carry buffer_map info (so they had to be lowered right away). This is no longer the case after the unified IR refactor. This PR migrates MakePackedAPI to a target-dependent stage and removes the unnecessary BindDevice pass.
* [RELAY] Remove re-exports of tvm.transform (apache#5337)
* [LLVM] Use llvm::FunctionCallee in IRBuilder::CreateCall with LLVM 11+ (apache#5338) The older variants of CreateCall have been deprecated and were recently removed from LLVM. This caused compilation failures.
* [CI] Fix build.sh to propagate --network=host to the docker build command (apache#5336)
* When passing --net=host to build.sh it also needs to be sent as --network=host to "docker build", so that both build and run use the same network configuration
* [Runtime][Relay][Cleanup] Clean up the memory pass to enable heterogeneous execution support. (apache#5324)
* Cleanup type pack and unpack for tuples.
* Clean up the memory_pass using common helpers
* Clean up memory.cc
* Refactor pass
* Add doc strings
* Fix cpplint
* Fix pylint
* Fix
* Apply suggestions from code review Co-Authored-By: Zhi <5145158+zhiics@users.noreply.github.com>
* Fix typo Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>
* Windows Support for cpp_rpc (apache#4857)
* Windows Support for cpp_rpc
* Add missing patches that fix crashes under Windows
* On Windows, use python to untar instead of wsl
* Remove some CMakeLists.txt stuff
* More minor CMakeLists.txt changes
* Remove items from CMakeLists.txt
* Minor CMakeLists.txt changes
* More minor CMakeLists.txt changes
* Even more minor CMakeLists.txt changes
* Modify readme
* [PYTORCH] Take, TopK op support (apache#5332)
* [PYTORCH] take, topk op support
* CI failure fix
* [TOPI] Use x86 schedules for ARM conv2d. (apache#5334)
* [TOPI] Improve get_valid_count and nms performance for CUDA (apache#5339)
* get_valid_count updated to have correct results
* speed up nms
* update nms
* revert back nms
* recover one test for get_valid_count
* [PYTHON] Enhance with_attr API, clean up MakeAPILegacy in testcases (apache#5335)
* [TIR] Remove ProducerConsumer and AllocateNode::new_expr (apache#5333) This PR removes two deprecated legacy IR parts in TIR. The ProducerConsumer node only serves as a hint markup and may no longer be informative after extensive transformations in a pass. If necessary, we can add related info via AttrStmt. The new_expr field in AllocateNode is deprecated since it can simply be replaced by a LetStmt.
  - Remove dependencies of passes on ProducerConsumer.
  - Remove ProducerConsumer from the IR.
  - Remove the deprecated fields (new_expr, free_function) from AllocateNode.
* Fix additional testcases
* [BYOC] Prevent duplicate outputs in subgraph Tuple (apache#5320)
* Fix duplicate output in partitiongraph
* Add test case
* Fix test_annotated_regions with duplicate compiler_end outputs
* Revert "Fix duplicate output in partitiongraph" This reverts commit e1f8ef3.
* Prevent duplicate outputs in Tuple in PartitionGraph
* Fix lint
* Add another test case for when regions are merged, and when TupleGetItem was duplicated
* Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput
* Use std::move for GetFunctionOutput. Fix typo in testcase name
* Use tvm.transform.Sequential
* [Tutorial, QNN] Add tutorial for loading quantized PyTorch model (apache#5321)
* add pytorch tutorial code and doc stub
* add more docs
* formatting, more docs
* typo fix
* try make sphinx happy
* add performance section
* type and nit fix
* format fix
* [DOCS] Bring relay docs to the top-level flat view (apache#5343)
  - Changes most of the relay docs to use autosummary.
  - Bring relay API docs to the top-level flat view for easier discovery.
  - Removed a few cases of re-exports.
* [TOPI][PYTORCH] Logical & Bitwise operator support (apache#5341)
* [RELAY][BYOC] Register pattern tables from external codegens (apache#5262) This adds utility functions to support registering and retrieving pattern tables used by MergeComposite for external codegens. Change-Id: I5be165a321440e48b15ff6aff4970e0c67496aaa
* Updated DNNL tests to use pattern table mechanism
* Removed pattern table standalone test
* Change reg to _op
* [RUNTIME][CRT] support DLTensor whose ndim == 0 (apache#5344) Signed-off-by: windclarion <windclarion@gmail.com>
* [BYOC][FIX] Fix typo in "default" (apache#5348) Default annotations were incorrectly being named 'defualt', which results in them not being removed in PartitionGraph.
* enable tsim and fsim for GPU build (apache#5352)
* [CRT] Compilation warnings fixed for 32bit and 64bit compilation (apache#5349)
* [PYTORCH] Tensor creation ops support (apache#5347)
* [Hexagon] Add hexagon_posix.cc to TVM/RT sources in the right place (apache#5346) This file was added before the variable with TVM/RT was initialized. The initialization overwrote the addition.
* [TOPI-ARM] Do not alter layout if layout is NHWC (apache#5350)
* [TOPI-ARM] Do not alter layout if layout is NHWC
* Add test.
* [TIR] Make lower_warp_memory support extent(threadIdx.x) < warp_size (apache#5307)
* support extent(threadIdx.x) < warp_size in lower_warp_memory
* more docs for lower_warp_memory
* [RELAY][PYTORCH] GroupNorm op support added (apache#5358)
* docker: Drop caffe2 download progress bars (apache#5359) Change-Id: Ia15c3c8f41f75423814e559f6fdb062098f19464
* fix fuse over functions that are handled by external codegen (apache#5365)
* [RUNTIME] FastRPC interface for Hexagon runtime (apache#5353) Co-authored-by: Ravishankar Kolachana <quic_rkolacha@quicinc.com> Co-authored-by: Krzysztof Parzyszek <kparzysz@quicinc.com>
* Explain store offset in a comment in launcher Co-authored-by: Abhikrant Sharma <quic_abhikran@quicinc.com> Co-authored-by: Ravishankar Kolachana <quic_rkolacha@quicinc.com>
* [TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Unified IR pass manager. (apache#5364)
  - Migrate BoundCheckers and Simplify
  - Migrate RewriteUnsafeSelect and RemoveNoOp
  - Migrate UnrollLoop and StorageRewrite
  - Migrate InjectDoubleBuffer and InjectVirtualThread
  - Migrate LoopPartition and Vectorize
  - Migrate CoProcSync, LiftAttrScope, InjectCopyIntrin
  We still keep the ir_pass registrations for now. A separate PR is needed to refactor the parts before StorageFlatten.
* [TIR] Fix lower_warp_memory when there are >1 warp buffers (apache#5368)
* fix recursion in lower_warp_memory
* post-order mutation
* Add cuda target check to dense tensorcore schedule. (apache#5376)
* Remove developer-facing api from frontend exports. (apache#5375)
* [TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes. (apache#5372) te::Tensor is a useful object for tensor expression, but it brings unnecessary reverse dependencies into TIR nodes such as Provide and Realize. This PR is a first step to remove this dependency. We will use Buffer in all the places where te::Tensor was used. The rough correspondences are:
  - Provide -> BufferStore
  - Realize -> BufferRealize
  - HalideCall -> BufferLoad
  After this change, we can now use an IRModule of PrimFuncs cleanly to represent TIR at any point of the optimizations. Buffer serves as the abstraction for the TIR data model to represent the intermediate storages and their constraints. We still keep Realize/HalideCall and Provide as TIR nodes for now to keep the change minimal. Right after ScheduleOps, we call SchedulePostProcToPrimFunc to canonicalize the temporary IR generated by TE (which contains these nodes) into TIR. The TIR optimizations are now mostly migrated to the pass manager. Follow-up PRs are needed to migrate the remaining few passes.
* Fix dev tutorial
* [PYTORCH] Unary Ops (apache#5378)
* [TIR][REFACTOR] RewriteForTensorCore -> te/schedule (apache#5379) RewriteForTensorCore depends on the schedule information, which makes it differ from a typical pass (which should get all its information from the input TIR). As a result, we refactor it as a SchedulePostProc step for now. We should revisit it later as we introduce more support for tensor core patterns in the TIR.
* Fix VTA to fit the new IR Pattern
* [Blocksparse] Pipeline for lowering dense model to sparse-dense (apache#5377)
* [REFACTOR][TE] Inline -> te/schedule/operation_inline.h (apache#5386) Rationale: inline is a transformation used in te to rewrite its internal expressions. It is not a formal IRModule->IRModule transform pass. Also removed the python test, as it is covered by stage.compute_inline.
* [ARITH] Remove the legacy Simplify, migrate to Analyzer. (apache#5385) The legacy Simplify/CanonicalSimplify are now a thin wrapper around the Analyzer. This PR removes these functions and migrates every place that requires simplification to enforce Analyzer creation. The new API encourages more Analyzer sharing and potentially enables context-aware analyzer-based simplification.
* [ARITH] Remove legacy const pattern functions (apache#5387)
* Add ability to have multiple copies of same input to onnx_inputs. (apache#5389)
* [Topi, ARM] Disable Winograd for quantized tensors. (apache#5363)
* [Topi, ARM] Disable Winograd for quantized tensors.
* Relaxing float
* Fix test_ir_type. (apache#5390) The void return type is not None/nullptr, it's VoidType or TupleType([]).
* Tf2 test fixups (apache#5391)
* Fix oversight in importing tf.compat.v1 as tf.
* Actually disable test for lstm in TF2.1. Since the testing framework actually uses pytest, the version check needs to be moved.
* [PYTHON] Migrate VTA TIR passes to the new pass manager. (apache#5397)
* [LLVM] Use ArrayRef<int> in calls to CreateShuffleVector (apache#5399) This switch was made in LLVM 11. Previously this function was expecting mask indices of type uint32_t. This variant is now deprecated.
* [KERAS] Minimum & AlphaDropout op support (apache#5380)
* Factor out import of common tflite.Operator in tflite frontend. (apache#5355)
* Restructure imports in tflite frontend. These python modules are needed for every tflite file parsed. Factor out the imports of the most common ones. Now that the import of operator is common, asserts can be commonized. Loses 473 lines of duplication.
* Only restrict to tflite.Operator
* [Fix] Remove the duplicate PrintIR pass in Relay (apache#5403)
* Update dmlc-core to latest (apache#5401)
* [TIR] Enhance Substitute, python bindings for Substitute/PostOrderVisit/IRTransform. (apache#5400) Substitute now takes a std::function to customize more replacing behaviors. Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>
* [Relay] Fix memory leak when accessing NDArray (apache#5413)
* Customize SI prefix in logging (apache#5411)
* Customize SI prefix in logging
* Include unit test
* [LLVM] Replace calls to Type::getVectorNumElements (apache#5398) This function has recently been removed from LLVM 11. Use an alternative way to obtain the vector element count (VectorType::getNumElements) which works for all LLVM versions.
* Don't remove() TempDirectory in __del__ after atexit hook runs. (apache#5414)
* Use atexit to remove TempDirectory before interpreter shutdown.
* Can't rely on complex functions from __del__ anyway.
* Fixes warning message on my box:
  Exception ignored in: <function TempDirectory.__del__ at 0x12be10680>
  Traceback (most recent call last):
    File ".../tvm/python/tvm/contrib/util.py", line 55, in __del__
    File ".../tvm/python/tvm/contrib/util.py", line 51, in remove
    File "/usr/local/opt/python/Frameworks/Python.framework/Versions/3.7/lib/python3.7/shutil.py", line 509, in rmtree
  AttributeError: 'NoneType' object has no attribute 'path'
* [TIR][REFACTOR] Remove ir_pass in favor of analysis/transform. (apache#5415) This PR removes ir_pass (old-style pass functions) in favor of analysis/transform (the new-style pass manager).
* [RUNTIME][CONTRIB] CoreML Runtime (apache#5283)
* [RUNTIME][CONTRIB] CoreML Runtime
* fix lint
* fix CI
* use xcrun to compile coreml model
* [DOCS] Migrate HLS documents from md to rst (apache#5419)
* fix [RUNTIME][VULKAN] vkBuffer released before memory copy command sent to GPU (apache#5388) (apache#5418)
* [Frontend] Asymmetric padding of convolution support (apache#4803)
* [cuDNN] Add cuDNN grouped convolutions support (apache#5319) Signed-off-by: Wei Pan <weip@nvidia.com>
* [CI] Migrate Tensorflow and Tensorflow lite in CI to 2.1.0 (apache#5392)
* Migrate Tensorflow and TFLite in the CI up to 1.15.2. The latest stable version of Tensorflow and Tensorflow lite in the 1.x series is 1.15.2. The tflite frontend is receiving support for versions of tflite > 1.14 but there is no consistent testing. There are 2 failures already in the source base with tf 1.15 and I'm concerned this will just get exacerbated over time if we don't have CI picking this up, and I view this as a stepping stone towards stepping CI to TF2.x. The test failures that I have commented out will have issues raised for them to be fixed.
* Comment out run of qnn_mobilenet_v3_net. This is another test that fails with TFlite 1.15.2.
* Skip the qnn_mobilenet_v3 test in the pytest fashion.
* Switch docker versions to support Tensorflow 2.1.0
* Fix up pytest imports and usage.
* Skip these tests currently for Tensorflow 2.1.0
* [DOCS] Migrate some markdowns to rst, fix sphinx3 warnings (apache#5416)
* [DOCS] Migrate some markdowns to rst, fix sphinx3 warnings
* Add note block
* [BYOC] Use Non-Recursive Visitor/Mutator (apache#5410)
* Non-Recursive AnnotatedTarget and MergeAnnotation
* Non-Recursive AnnotatedRegionSet and RegionMerger
* [RFC] Pytest environment improvements (apache#5421)
* [RFC] Pass pytest options globally. In many places having a global pytest flag is useful. For my build and test of tvm, I would like to be able to globally pass in pytest options as part of development or CI flows where one would like to regularly measure other things, including pytest coverage data that I would like to experiment with across the stack. This has been achieved with an additional setup-pytest-env.sh file in tests/scripts rather than putting something in every single task test script, which I would like to avoid. This now means the -v option to pytest is superfluous. I did consider having a pytest.ini file, but that doesn't allow me to pass in arbitrary environment variables, and this seems to be the compromise.
* Improve other use case documentation
* Rationalize pytest environment.
* Remove the setting from docker/with_same_user.
* Take the opportunity to migrate common PYTHONPATH and TVM_PATH into the common environment setting.
* Fixup vta fsim
* Be more explicit with common PYTHONPATH
* Fix python path for task_python_vta_fsim.sh properly
* Fix nit in documentation.
* [MXNET] DepthToSpace & SpaceToDepth Operator (apache#5408)
* Add option to specify flatbuffers location (apache#5425)
* [FRONTEND][MXNET] support elemwise logic ops (apache#5361)
* [PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass str (apache#5426) To make runtime.String work as naturally as possible on the python side, we make it subclass python's str object. Note, however, that we cannot subclass Object at the same time due to python's type layout constraint. We introduce a PyNativeObject class to handle this kind of object subclassing and update the FFI to handle PyNativeObject classes.
* [PYTORCH] where, addcdiv, addcmul op support (apache#5383)
* [PYTORCH] Where, addcdiv, addcmul op support
* Review comments fixed
* [FRONTEND][TFLITE] Gather, StridedSlice op support added (apache#4788)
* [FRONTEND][TFLITE] Gather, StridedSlice op added
* Review comments fixed
* misc fixes for ROCm (pointer lifetime, runtime::String refactor) (apache#5431)
* Corrected TVM autotuning on GPU (apache#5432) Added missing "tir" in tvm.tir.analysis.verify_gpu_code(f, kwargs)
* [RUNTIME][OBJECT] Introduce static slots for common objects. (apache#5423) The _type_child_slots can be used to enable quick type-checking optimization by checking whether the type index is within the bound. This PR enables these static slots:
  - Introduce a static assert to avoid the scenario where a developer forgets to set _type_child_slots when the field is set for the type's parent.
  - Revamp and assign static type indices to common runtime objects.
  - Add a DumpTypeTable call to allow a developer to monitor the current state of the type table and offer suggestions for the slots (ideally the slots equal the number of children so there is no overflow).
* [RELAY][PYTORCH] cosh, sinh, log2, log10, log1p op support (apache#5395)
* [RELAY][PYTORCH] cosh, sinh, log2, log10, log1p op support
* Review comment fixed
* Gradient testcase added
* [PYTORCH] Rsub, Embedded, OneHot ops support (apache#5434)
* fix miopen pad (apache#5433)
* [TOPI,RELAY][TFLITE] Sparse to dense operator Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
* Add TopK to ONNX Frontend (apache#5441)
* Add TopK to ONNX Frontend
* respond to review comments
* [CodeGen] Cleanup generated code (apache#5424)
  - remove unnecessary white spaces from storage kind
  - do not start a new scope for vectorization, as temporary variables are all uniquely generated.
  The above two changes make vectorized code much cleaner. Signed-off-by: Wei Pan <weip@nvidia.com>
* [RELAY] Move frontend utils (apache#5345) The util file currently under frontend is used from outside of frontend (in qnn/op/legalizations). This suggests that the file should be pushed up to a higher level. The benefit of this change is that importing qnn no longer also imports all the frontends.
* Inline get_scalar_from_constant Change-Id: I1cc64e9ecb0eadb6ac0f7b62e6ea174644af4ad4
* Remove util.py from Relay Change-Id: If9cd7cf3fc0bd1861a3a9b5604f338e084d8db96
* Shorten functions Change-Id: Ieb537d82e6ee52421ff05a90cd00a03679ffebf2
* Line length Change-Id: I1d216b7e73a060c4f118f5da50ce58b18eba907f
* [KERAS] Embedding layer (apache#5444)
* [Docs] VTA install doc migration from md to rst (apache#5442)
* Improve IntervalSet's floormod (apache#5367)
* use param name in documentation Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
* [ONNX] GatherNd, Round, IsNaN, IsInf (apache#5445)
* [relay][topi] Add operation relay.nn.dilate() which calls topi.nn.dilate() (apache#5331)
* Add operation relay.nn.dilate() which calls topi.nn.dilate().
* Fix typo
* Set op pattern to injective
* sphinx doc errors fixed Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
* [Pytorch] fix translation of transpose when axis argument is a list (apache#5451)
* incorporated code review comments Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
* Fixed indentation Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>

Co-authored-by: Samuel <siju.samuel@huawei.com>
Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Krzysztof Parzyszek <kparzysz@quicinc.com>
Co-authored-by: Leandro Nunes <leandro.nunes@arm.com>
Co-authored-by: Jared Roesch <jroesch@octoml.ai>
Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>
Co-authored-by: jmorrill <jeremiah.morrill@gmail.com>
Co-authored-by: Animesh Jain <anijain@umich.edu>
Co-authored-by: Leyuan Wang <laurawly@gmail.com>
Co-authored-by: Trevor Morris <trevmorr@amazon.com>
Co-authored-by: masahi <masahi129@gmail.com>
Co-authored-by: mbaret <55580676+mbaret@users.noreply.github.com>
Co-authored-by: windclarion <windclarion@gmail.com>
Co-authored-by: Tang, Shizhi <rd0x01@gmail.com>
Co-authored-by: Marcus Shawcroft <marcus.shawcroft@arm.com>
Co-authored-by: Abhikrant Sharma <quic_abhikran@quicinc.com>
Co-authored-by: Ravishankar Kolachana <quic_rkolacha@quicinc.com>
Co-authored-by: Josh Fromm <jwfromm@uw.edu>
Co-authored-by: shoubhik <shoubhikbhatti@gmail.com>
Co-authored-by: Bing Xu <antinucleon@gmail.com>
Co-authored-by: Andrew Reusch <areusch@octoml.ai>
Co-authored-by: Ramana Radhakrishnan <ramana.radhakrishnan@arm.com>
Co-authored-by: Haichen Shen <shenhaichen@gmail.com>
Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>
Co-authored-by: MORITA Kazutaka <morita.kazutaka@gmail.com>
Co-authored-by: samwyi <samwyi@yahoo.com>
Co-authored-by: Zhao Wu <zhaowu@apache.org>
Co-authored-by: Wei Pan <60017475+wpan11nv@users.noreply.github.com>
Co-authored-by: Cody Yu <comaniac0422@gmail.com>
Co-authored-by: Michal Piszczek <imichaljp@gmail.com>
Co-authored-by: Thomas Viehmann <tv.code@beamnet.de>
Co-authored-by: JishinMaster <francois.turban@gmail.com>
Co-authored-by: Matthew Brookhart <matthewbrookhart@gmail.com>
Co-authored-by: Thierry Moreau <tmoreau@octoml.ai>
Co-authored-by: yongfeng-nv <49211903+yongfeng-nv@users.noreply.github.com>
Co-authored-by: notoraptor <notoraptor@users.noreply.github.com>
Co-authored-by: Nikolay Nez <34389970+n-nez@users.noreply.github.com>
1 parent 8cbdcd6 commit 3da94a0

File tree: 14 files changed (+414, -6 lines)
docs/api/python/topi.rst

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,7 @@ List of operators
    topi.expand_dims
    topi.reshape
    topi.unravel_index
+   topi.sparse_to_dense
    topi.squeeze
    topi.concatenate
    topi.split
@@ -154,6 +155,7 @@ topi
 .. autofunction:: topi.expand_dims
 .. autofunction:: topi.reshape
 .. autofunction:: topi.unravel_index
+.. autofunction:: topi.sparse_to_dense
 .. autofunction:: topi.squeeze
 .. autofunction:: topi.concatenate
 .. autofunction:: topi.split

docs/langref/relay_op.rst

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ This level enables additional math and transform operators.
    tvm.relay.tile
    tvm.relay.reverse
    tvm.relay.unravel_index
+   tvm.relay.sparse_to_dense


 **Level 4: Broadcast and Reductions**

include/tvm/relay/attrs/transform.h

Lines changed: 10 additions & 0 deletions
@@ -288,6 +288,16 @@ struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
   }
 };  // struct SequenceMaskAttrs.
 
+/*! \brief Attributes used in sparse_to_dense operator */
+struct SparseToDenseAttrs : public tvm::AttrsNode<SparseToDenseAttrs> {
+  Array<Integer> output_shape;
+
+  TVM_DECLARE_ATTRS(SparseToDenseAttrs, "relay.attrs.SparseToDenseAttrs") {
+    TVM_ATTR_FIELD(output_shape)
+        .describe("Shape of the dense output tensor");
+  }
+};  // struct SparseToDenseAttrs
+
 /*! \brief Attributes for ndarray_size operator */
 struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
   DataType dtype;

python/tvm/relay/frontend/tflite.py

Lines changed: 31 additions & 0 deletions
@@ -124,6 +124,7 @@ def __init__(self, model, subgraph, exp_tab):
             'SOFTMAX': self.convert_softmax,
             'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
             'SPACE_TO_DEPTH': self.convert_space_to_depth,
+            'SPARSE_TO_DENSE': self.convert_sparse_to_dense,
             'SPLIT': self.convert_split,
             'SQRT': self.convert_sqrt,
             'SQUARE': self.convert_square,
@@ -2135,6 +2136,36 @@ def convert_space_to_depth(self, op):
 
         return out
 
+    def convert_sparse_to_dense(self, op):
+        """Convert TFLite SPARSE_TO_DENSE"""
+        try:
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 4, "input tensors length should be 4"
+
+        for t in input_tensors:
+            assert not t.qnn_params, "Quantized input is not expected."
+
+        indices, values = input_tensors[0], input_tensors[2]
+        default_value = input_tensors[3]
+        output_shape = input_tensors[1]
+
+        for t in [indices, output_shape]:
+            t_type = t.tensor.Type()
+            assert t_type in (TensorType.INT32, TensorType.INT64)
+
+        out = _op.sparse_to_dense(
+            self.get_expr(indices.tensor_idx),
+            self.get_expr(values.tensor_idx),
+            self.get_expr(default_value.tensor_idx),
+            list(self.get_tensor_value(output_shape))
+        )
+
+        return out
+
     def convert_prelu(self, op):
         """Convert TFLite PReLU"""
         input_tensors = self.get_input_tensors(op)
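The converter above maps TFLite's SPARSE_TO_DENSE builtin onto relay's sparse_to_dense, reordering TFLite's operand layout (indices, output_shape, values, default_value) into Relay's (indices, values, default_value, output_shape). The following is a minimal, hedged sketch of exercising the new converter end to end; it assumes TF 1.x (>= 1.14) with the tflite package installed, and the placeholder names and converter call are illustrative only, not part of this commit.

# Hedged sketch: convert a tf.sparse_to_dense graph through TFLite into Relay.
# Assumptions: TF 1.x (>= 1.14), the tflite python package, and this commit's TVM build.
import numpy as np
import tensorflow as tf
import tflite.Model
from tvm import relay

with tf.Graph().as_default() as graph:
    indices = tf.placeholder(shape=(3,), dtype="int32", name="indices")
    values = tf.placeholder(shape=(3,), dtype="int32", name="values")
    # tf.sparse_to_dense argument order: indices, output_shape, values, default_value
    dense = tf.sparse_to_dense(indices, tf.constant([5], dtype="int32"), values, 0)
    with tf.Session(graph=graph) as sess:
        converter = tf.lite.TFLiteConverter.from_session(sess, [indices, values], [dense])
        tflite_buf = converter.convert()

# Parse the flatbuffer and import it; the printed module should contain sparse_to_dense.
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"indices": (3,), "values": (3,)},
    dtype_dict={"indices": "int32", "values": "int32"})
print(mod)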

python/tvm/relay/op/_transform.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@
 _reg.register_injective_schedule("one_hot")
 _reg.register_reduce_schedule("collapse_sum_like")
 _reg.register_injective_schedule("unravel_index")
+_reg.register_injective_schedule("sparse_to_dense")
 
 # concatenate
 _reg.register_schedule("concatenate", strategy.schedule_concatenate)

python/tvm/relay/op/transform.py

Lines changed: 29 additions & 0 deletions
@@ -884,3 +884,32 @@ def unravel_index(indices, shape):
     """
 
     return _make.unravel_index(indices, shape)
+
+def sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
+    """Converts a sparse representation into a dense tensor.
+
+    Example::
+    -   sparse_to_dense([[0, 0], [1, 1]], [3, 3], 0, [2, 2]) = [[3, 0], [0, 3]]
+
+    Parameters
+    ----------
+    sparse_indices : relay.Expr
+        A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values.
+
+    sparse_values : relay.Expr
+        A 0-D or 1-D tensor containing the sparse values for the sparse indices.
+
+    default_value : relay.Expr
+        A 0-D tensor containing the default value for the remaining locations.
+        Defaults to 0.
+
+    output_shape : relay.Expr
+        A list of integers. Shape of the dense output tensor.
+
+    Returns
+    -------
+    result : relay.Expr
+        Dense tensor of shape output_shape. Has the same type as sparse_values.
+    """
+
+    return _make.sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape)
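For reference, here is a minimal sketch of calling the new Python API directly, mirroring the docstring example above. It assumes a TVM build that includes this commit and follows the executor pattern used in the tests further down; the variable names are illustrative.

# Hedged sketch: relay.sparse_to_dense([[0, 0], [1, 1]], [3, 3], 0, [2, 2]) -> [[3, 0], [0, 3]]
import numpy as np
import tvm
from tvm import relay

indices = relay.var("indices", relay.TensorType((2, 2), "int32"))
values = relay.var("values", relay.TensorType((2,), "int32"))
default = relay.var("default", relay.TensorType((), "int32"))
out = relay.sparse_to_dense(indices, values, default, [2, 2])

func = relay.Function([indices, values, default], out)
intrp = relay.create_executor("graph", ctx=tvm.cpu(), target="llvm")
result = intrp.evaluate(func)(
    np.array([[0, 0], [1, 1]], dtype="int32"),
    np.array([3, 3], dtype="int32"),
    np.int32(0),
)
print(result.asnumpy())  # expected: [[3, 0], [0, 3]]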

src/relay/op/tensor/transform.cc

Lines changed: 81 additions & 0 deletions
@@ -2732,5 +2732,86 @@ Example::
 .set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// sparse_to_dense
+TVM_REGISTER_NODE_TYPE(SparseToDenseAttrs);
+
+bool SparseToDenseRel(const Array<Type>& types,
+                      int num_inputs,
+                      const Attrs& attrs,
+                      const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 3);
+  auto sparse_indices = types[0].as<TensorTypeNode>();
+  auto sparse_values = types[1].as<TensorTypeNode>();
+  auto default_value = types[2].as<TensorTypeNode>();
+  CHECK(sparse_indices != nullptr && sparse_values != nullptr && default_value != nullptr);
+
+  CHECK(sparse_indices->dtype.is_int())
+      << "sparse_indices must be tensor of integers";
+
+  CHECK_LE(sparse_indices->shape.size(), 3)
+      << "sparse_indices must be a tensor of either 0D, 1D or 2D";
+
+  CHECK_LE(sparse_values->shape.size(), 2)
+      << "sparse_values must be a tensor of either 0D or 1D";
+
+  CHECK_EQ(default_value->shape.size(), 0)
+      << "default_value should be a scalar";
+
+  const auto* param = attrs.as<SparseToDenseAttrs>();
+  CHECK(param != nullptr);
+
+  Array<IndexExpr> oshape;
+  for (auto i : param->output_shape) {
+    oshape.push_back(i);
+  }
+  reporter->Assign(types[3], TensorType(oshape, sparse_values->dtype));
+  return true;
+}
+
+Array<te::Tensor> SparseToDenseCompute(const Attrs& attrs,
+                                       const Array<te::Tensor>& inputs,
+                                       const Type& out_type) {
+  CHECK_EQ(inputs.size(), 3);
+  const auto* param = attrs.as<SparseToDenseAttrs>();
+  CHECK(param != nullptr);
+  return {topi::sparse_to_dense(inputs[0], inputs[1], inputs[2], param->output_shape)};
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.sparse_to_dense")
+    .set_body_typed([](Expr indices, Expr values, Expr default_value, Array<Integer> output_shape) {
+      auto attrs = make_object<SparseToDenseAttrs>();
+      attrs->output_shape = std::move(output_shape);
+      static const Op& op = Op::Get("sparse_to_dense");
+      return Call(op, {indices, values, default_value}, Attrs(attrs));
+    });
+
+RELAY_REGISTER_OP("sparse_to_dense")
+    .describe(R"code(A dense tensor from a sparse representation.
+
+- **sparse_indices**: A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values
+
+- **sparse_values**: A 0-D or 1-D tensor containing the sparse values for the sparse indices.
+
+- **default_value**: A 0-D tensor containing the default value for the remaining locations. Defaults to 0.
+
+- **output_shape**: A list of integers. Shape of the dense output tensor.
+
+Example::
+  -  sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4]) = [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .set_support_level(3)
+    .set_attrs_type<SparseToDenseAttrs>()
+    .add_argument("sparse_indices", "Tensor", "Contains sparse indices.")
+    .add_argument("sparse_values", "Tensor", "Contains values for sparse indices.")
+    .add_argument("default_value", "Tensor", "Value to set for non-sparse indices. Defaults to 0.")
+    .add_type_rel("SparseToDense", SparseToDenseRel)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);
+
 }  // namespace relay
 }  // namespace tvm

tests/python/frontend/onnx/test_forward.py

Lines changed: 4 additions & 4 deletions (whitespace-only changes: trailing spaces removed)
@@ -2409,7 +2409,7 @@ def verify_topk(input_dims, K, axis=-1):
                        inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
                                helper.make_tensor_value_info("K", TensorProto.INT64, [1,])],
                        initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])],
-                       outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), 
+                       outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims),
                                 helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)])
 
         model = helper.make_model(graph, producer_name='topk_test')
@@ -2418,10 +2418,10 @@ def verify_topk(input_dims, K, axis=-1):
         onnx_out = get_onnxruntime_output(model, [indata, k])
 
         for target, ctx in [('llvm', tvm.cpu())]:
-            tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims], 
+            tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims],
                                      output_dtype=['float32', 'int64'])
             tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
-    
+
     for n in [12, 32]:
         for shape in [[n], [n, n], [n, n, n]]:
             for k in [1, 5, 10]:
@@ -2430,7 +2430,7 @@ def verify_topk(input_dims, K, axis=-1):
     verify_topk([n, n, n], 5, 0)
     verify_topk([n, n, n], 5, 1)
     verify_topk([n, n, n], 5, 2)
-    
+
 
 def test_roi_align():
     def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0):

tests/python/frontend/tflite/test_forward.py

Lines changed: 64 additions & 1 deletion
@@ -766,7 +766,6 @@ def test_all_resize():
     if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
         _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)
 
-
 #######################################################################
 # Range
 # -----
@@ -1702,6 +1701,69 @@ def test_forward_spacetodepth():
     _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
     _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)
 
+#######################################################################
+# Sparse To Dense
+# ---------------
+def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
+    # tflite 1.13 convert method does not accept empty shapes
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        with tf.Graph().as_default():
+            indices = tf.placeholder(shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices")
+            values = tf.placeholder(shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values")
+            oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype))
+            dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value")
+
+            output = tf.sparse_to_dense(indices, oshape, values, dv)
+
+            compare_tflite_with_tvm(
+                [sparse_indices, sparse_values, default_value],
+                ["indices", "values", "default_value"],
+                [indices, values, dv],
+                [output]
+            )
+
+def test_forward_sparse_to_dense():
+    '''
+    Works in tvm/topi/tensorflow. But tflite converter breaks this test case
+    _test_sparse_to_dense(
+        np.int32(1),
+        np.int32(3),
+        np.int32(0),
+        np.array([5]).astype("int32")
+    )
+    '''
+
+    # vector
+    _test_sparse_to_dense(
+        np.array([0, 1, 4]).astype("int32"),
+        np.array([3, 3, 3]).astype("int32"),
+        np.int32(0),
+        np.array([5]).astype("int32")
+    )
+
+    # vector nXd
+    _test_sparse_to_dense(
+        np.array([[0, 0], [1, 2]]).astype("int32"),
+        np.array([1, 2]).astype("int32"),
+        np.int32(0),
+        np.array([3, 4]).astype("int32")
+    )
+
+    _test_sparse_to_dense(
+        np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"),
+        np.array([1, 2]).astype("int32"),
+        np.int32(4),
+        np.array([2, 3, 4]).astype("int32")
+    )
+
+    # floats
+    _test_sparse_to_dense(
+        np.array([0, 1, 4]).astype("int32"),
+        np.array([3.1, 3.1, 3.1]).astype("float32"),
+        np.float32(3.5),
+        np.array([5]).astype("int32")
+    )
+
 #######################################################################
 # Fully Connected
 # ---------------
@@ -2055,6 +2117,7 @@ def test_forward_mediapipe_hand_landmark():
     test_forward_stridedslice()
     test_forward_depthtospace()
     test_forward_spacetodepth()
+    test_forward_sparse_to_dense()
 
     # NN
     test_forward_convolution()

tests/python/relay/test_op_level3.py

Lines changed: 48 additions & 1 deletion
@@ -747,6 +747,52 @@ def verify_unravel_index(indices, shape, dtype):
     # output which is inline with Tensorflow
     # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype)
 
+def test_sparse_to_dense():
+    def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected):
+
+        sparse_indices_data = np.array(sparse_indices)
+        sparse_values_data = np.array(sparse_values)
+        default_value_data = np.array(default_value)
+
+        a = relay.var("a", relay.TensorType(sparse_indices_data.shape, str(sparse_indices_data.dtype)))
+        b = relay.var("b", relay.TensorType(sparse_values_data.shape, str(sparse_values_data.dtype)))
+        c = relay.var("c", relay.TensorType(default_value_data.shape, str(default_value_data.dtype)))
+        d = relay.sparse_to_dense(a, b, c, output_shape)
+
+        zz = run_infer_type(d)
+        assert zz.checked_type == relay.ty.TensorType(output_shape, str(sparse_values_data.dtype))
+
+        func = relay.Function([a, b, c], d)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(
+                    sparse_indices_data, sparse_values_data, default_value_data
+                )
+                tvm.testing.assert_allclose(op_res.asnumpy(), xpected, rtol=1e-5)
+
+
+    verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0])  # scalar
+    verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3])  # vector
+    verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])  # nXd
+
+    verify_sparse_to_dense(
+        [[0, 0, 0], [1, 2, 3]],
+        [1, 2],
+        4,
+        [2, 3, 4],
+        [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]]
+    )  # nXd
+
+    verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])  # floats
+
+    # negative test cases
+    # sparse indices should be ints
+    # verify_sparse_to_dense([[0.1, 1.1, 4.1], [0, 2, 4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+    # sparse_values should be 0d or 1d only
+    # verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
+    # sparse_indices should not be > 2d tensor
+    # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
 
 if __name__ == "__main__":
     test_arange()
@@ -780,4 +826,5 @@ def verify_unravel_index(indices, shape, dtype):
     test_gather_nd()
     test_isfinite()
     test_isinf()
-    test_unravel_index()
+    test_unravel_index()
+    test_sparse_to_dense()
