Skip to content

Commit f198c5f

Browse files
author
Matthew Brookhart
committed
add dynamic strided slice to the onnx importer
1 parent 2892e6a commit f198c5f

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -945,7 +945,6 @@ def _impl_v9(cls, inputs, attr, params):
945945
return out
946946

947947

948-
949948
class Shape(OnnxOpConverter):
950949
"""Operator converter for Shape."""
951950

@@ -1047,24 +1046,35 @@ def _impl_v1(cls, inputs, attr, params):
10471046

10481047
@classmethod
10491048
def _impl_v10(cls, inputs, attr, params):
1050-
attrs = {"starts": inputs[1], "ends": inputs[2]}
1051-
if len(inputs) >= 4:
1052-
attrs["axes"] = inputs[3]
1053-
attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()}
1054-
attrs = {
1055-
k: params[v[1]].asnumpy()
1056-
if v[1] in params
1057-
else infer_value_simulated(v[0], params).asnumpy()
1058-
for (k, v) in attrs.items()
1059-
}
1049+
starts = inputs[1]
1050+
ends = inputs[2]
1051+
axes = inputs[3]
1052+
steps = inputs[4]
1053+
1054+
data_rank = len(infer_shape(inputs[0]))
10601055

10611056
# Update the starts and ends according to axes if required.
1062-
if "axes" in attrs:
1063-
if max(attrs["axes"] + 1) != len(attrs["axes"]):
1064-
new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"])
1065-
attrs["starts"] = new_starts
1066-
attrs["ends"] = new_ends
1067-
return _op.strided_slice(inputs[0], begin=list(attrs["starts"]), end=list(attrs["ends"]))
1057+
if axes is not None:
1058+
data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
1059+
starts = _op.scatter(
1060+
_op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype),
1061+
axes,
1062+
starts,
1063+
axis=0,
1064+
)
1065+
ends = _op.scatter(data_shape, axes, ends, axis=0)
1066+
if steps is not None:
1067+
steps = _op.scatter(
1068+
_op.const([1] * data_rank, dtype=infer_type(steps).checked_type.dtype),
1069+
axes,
1070+
steps,
1071+
axis=0,
1072+
)
1073+
1074+
if steps is None:
1075+
steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype)
1076+
1077+
return _op.strided_slice(inputs[0], starts, ends, steps)
10681078

10691079

10701080
class Gather(OnnxOpConverter):
@@ -1406,7 +1416,6 @@ def _impl_v6(cls, inputs, attr, params):
14061416
return _op.tile(inputs[0], inputs[1])
14071417

14081418

1409-
14101419
class Erf(OnnxOpConverter):
14111420
"""Operator converter for Erf"""
14121421

tests/python/frontend/onnx/test_forward.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,12 +599,13 @@ def add_noop_to_input_attr(attr_name, attr):
599599
model = helper.make_model(graph, producer_name="slice_test")
600600

601601
for target, ctx in tvm.testing.enabled_targets():
602-
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=10)
602+
tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=10, freeze_params=True)
603603

604604
tvm.testing.assert_allclose(outdata, tvm_out)
605605

606606

607-
@tvm.testing.uses_gpu
607+
# TODO(mbrookhart): enable once VM supports heterogenous execution
608+
# @tvm.testing.uses_gpu
608609
def test_slice():
609610
x = np.random.randn(20, 10, 5).astype(np.float32)
610611
_test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1))

0 commit comments

Comments
 (0)