Skip to content

Commit 2658ebe

Browse files
Matthew BrookhartLily Orth-Smith
andauthored
Dynamic ONNX Importer (#6351)
* Change onnx importer to use dynamic upsampling3d (#3) fix pylint * Refactor ONNX frontend to be dynamic Make OneHot dynamic Support BatchMatMul with dynamically shaped inputs fix dynamic broadcast Add null checks to broadcast_to rel functions fail more isolated broadcast_to test use StructuralEqual instead of pointer comparisions in dynamic_to_static pass add an optional weight freeze argument to onnx importer convert onnx resize to dynamic op add dynamic expand to onnx importer add a shape_func for power fix BERTSquad, lint handle onnx graph initializer parameters more intelligently * Dynamic ONNX importer: Upsampling and Pad (#2) fix lint fix Call reference fix a type issue with expand fix a bad test refactor respond to review comments, fix batch matmul tests * black format * fix batch matmul test * add dynamic strided slice to the onnx importer * fix clip importer * fix qnn tutorial * fix bad merge, respond to review comments * add a simple dynamic model test * Add dynamic-shaped autopadding to convolution and pooling ops * fix dynamic issues in a few ops * fix pylint * disable tests onnxrt doesn't support * fix pytorch test * respond to review comments * add documentation about partially supporting dynamic shapes Co-authored-by: Lily Orth-Smith <lorthsmith@octoml.ai>
1 parent a413458 commit 2658ebe

File tree

21 files changed

+957
-489
lines changed

21 files changed

+957
-489
lines changed

include/tvm/relay/transform.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,17 @@ TVM_DLL Pass SimplifyInference();
208208
*/
209209
TVM_DLL Pass FastMath();
210210

211+
/*!
212+
* \brief Find Dynamic ops and make them static
213+
*
214+
* Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces
215+
* them with static ops and re-performs type inference and constant folding. The pass repeats
216+
* itself until the graph stops changing or we run too many iterations.
217+
*
218+
* \return The pass.
219+
*/
220+
TVM_DLL Pass DynamicToStatic();
221+
211222
/*!
212223
* \brief Infer the type of an expression.
213224
*

include/tvm/topi/broadcast.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,19 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
5454
<< "\nvs\ninput: " << t;
5555
auto bh = detail::BroadcastShape(output_shape, t->shape);
5656
CHECK_EQ(output_shape.size(), bh.common_shape.size());
57+
Array<PrimExpr> oshape;
5758
for (size_t i = 0; i < output_shape.size(); ++i) {
58-
CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
59+
if (output_shape[i].as<tir::IntImmNode>() == nullptr) {
60+
oshape.push_back(output_shape[i]);
61+
} else {
62+
CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
63+
oshape.push_back(bh.common_shape[i]);
64+
}
5965
}
6066
auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
6167
return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
6268
};
63-
return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
64-
l, name, tag);
69+
return tvm::te::compute(oshape, l, name, tag);
6570
}
6671

6772
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \

0 commit comments

Comments
 (0)