From 46a8ab062fa525e77e88740821c4274ac0d0f76c Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 8 Oct 2020 10:56:42 -0600 Subject: [PATCH] Add Range op to ONNX, make tvm arange op support negative steps (#6647) --- python/tvm/relay/frontend/onnx.py | 14 +++++++++ python/tvm/relay/op/_transform.py | 5 ++- tests/python/frontend/onnx/test_forward.py | 36 ++++++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ce5084e1ece4..0598094398f7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1893,6 +1893,19 @@ def _impl_v1(cls, inputs, attr, params): return _op.topk(inputs[0], inputs[1], axis=axis) +class Range(OnnxOpConverter): + """Operator converter for Range""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + if len(inputs) != 3: + raise ValueError("Expect 3 input only") + + return _op.arange( + inputs[0], inputs[1], inputs[2], dtype=infer_type(inputs[0]).checked_type.dtype + ) + + class MaxRoiPool(OnnxOpConverter): """Operator converter for MaxRoiPool.""" @@ -2115,6 +2128,7 @@ def _get_convert_map(opset): "Or": Or.get_converter(opset), "Resize": Resize.get_converter(opset), "NonZero": NonZero.get_converter(opset), + "Range": Range.get_converter(opset), } diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a2fab248b2dd..a852bb0bd7a5 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -123,7 +123,10 @@ def compute_scatter_add(attrs, inputs, output_type): @script def _arange_shape_func(start, stop, step): out = output_tensor((1,), "int64") - out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0]))) + if step[0] < 0: + out[0] = int64(ceil_div((int64(start[0]) - int64(stop[0])), int64(-step[0]))) + else: + out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0]))) return out diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 799d8971b853..ae32012e42e8 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -410,6 +410,42 @@ def test_power(): _test_power_iteration((2, 3), (1, 3)) +def verify_range(start, limit, delta, dtype): + dtype_map = { + "float32": TensorProto.FLOAT, + "int32": TensorProto.INT32, + "int64": TensorProto.INT64, + } + dtype_onnx = dtype_map[dtype] + y = helper.make_node("Range", ["start", "limit", "delta"], ["output"]) + graph = helper.make_graph( + [y], + "range_test", + inputs=[ + helper.make_tensor_value_info("start", dtype_onnx, []), + helper.make_tensor_value_info("limit", dtype_onnx, []), + helper.make_tensor_value_info("delta", dtype_onnx, []), + ], + outputs=[ + helper.make_tensor_value_info( + "output", dtype_onnx, np.arange(start, limit, delta).shape + ) + ], + ) + model = helper.make_model(graph, producer_name="range_test") + inputs = [np.array(x).astype(dtype) for x in [start, limit, delta]] + verify_with_ort_with_inputs(model, inputs, use_vm=True) + + +@tvm.testing.uses_gpu +def test_range(): + for t in ["float32", "int32", "int64"]: + verify_range(0, 10, 1, t) + verify_range(2, 8, 2, t) + verify_range(-3, 6, 4, t) + verify_range(-2, -7, -1, t) + + @tvm.testing.uses_gpu def test_squeeze(): in_shape = (1, 3, 1, 3, 1, 1)