From 5e3a7e5c1b73247d3eafb0c44453d438ad139db9 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Mon, 10 Sep 2018 11:59:01 -0700 Subject: [PATCH] ONNX export - Clip operator (#12457) --- .../contrib/onnx/mx2onnx/_op_translations.py | 24 +++++++++++++++++++ .../onnx/export/onnx_backend_test.py | 3 ++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 0960776251c4..3ffac96a14e1 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1057,6 +1057,30 @@ def convert_flatten(node, **kwargs): ) return [flatten_node] +@mx_op.register("clip") +def convert_clip(node, **kwargs): + """Map MXNet's Clip operator attributes to onnx's Clip operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + input_idx = kwargs["index_lookup"][node["inputs"][0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_idx].name + attrs = node["attrs"] + a_min = np.float(attrs.get('a_min', -np.inf)) + a_max = np.float(attrs.get('a_max', np.inf)) + + clip_node = helper.make_node( + "Clip", + [input_node], + [name], + name=name, + min=a_min, + max=a_max + ) + return [clip_node] + def scalar_op_helper(node, op_name, **kwargs): """Helper function for scalar arithmetic operations""" diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index 19bf6993e7cd..01ae09402ef5 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -89,7 +89,8 @@ 'test_operator_exp', 'test_operator_maxpool', 'test_operator_params', - 'test_operator_permute2' + 'test_operator_permute2', + 'test_clip' ] BASIC_MODEL_TESTS = [