Commit 73e38a8

srkreddy1238 authored and wweic committed

[FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay. (apache#2850)

* [FRONTEND][ONNX] Some bug fixes and Shape operator fixed for relay.
* test cases
* ci error

1 parent bbcbcbb commit 73e38a8

3 files changed: +57 −27 lines

python/tvm/relay/frontend/common.py

Lines changed: 8 additions & 1 deletion
@@ -321,6 +321,10 @@ def __call__(self, inputs, attrs, *args):
         else:
             assert callable(self._op_name), "op_name can either be string or callable"
             op_name = self._op_name(attrs)
+
+        # ignore 'tvm_custom' always
+        self._ignores.append('tvm_custom')
+
         # convert attributes
         new_attrs = {}
         for k in attrs.keys():
@@ -329,7 +333,8 @@ def __call__(self, inputs, attrs, *args):
             elif k in self._disables:
                 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
             elif k in self._ignores:
-                logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)
+                if k != 'tvm_custom':
+                    logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name)
             elif k in self._transforms:
                 new_name, defaults, transform = self._parse_default(self._transforms[k])
                 if defaults is None:
@@ -416,4 +421,6 @@ def __init__(self, new_name):
         self._new_name = new_name

     def __call__(self, inputs, attrs, *args):
+        if 'tvm_custom' in attrs:
+            attrs.pop('tvm_custom')
         return get_relay_op(self._new_name)(*inputs, **attrs)
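Taken together, the three hunks above establish a convention: 'tvm_custom' is importer-side metadata, never a real operator attribute, so AttrCvt always ignores it and Renamer pops it before invoking the wrapped op. Below is a minimal standalone sketch of that strip step, with illustrative names only (not TVM's actual classes):

def renamer_call(op_func, inputs, attrs):
    # Sketch of what Renamer.__call__ does after this change: drop the importer
    # metadata so the wrapped op never receives an unexpected keyword argument.
    attrs = dict(attrs)
    attrs.pop('tvm_custom', None)
    return op_func(*inputs, **attrs)

def fake_concat(*inputs, axis=0):
    # Hypothetical stand-in for a relay op that accepts only real attributes.
    return ('concat', inputs, axis)

print(renamer_call(fake_concat, ['x', 'y'],
                   {'axis': 1, 'tvm_custom': {'name': 'node_0'}}))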

python/tvm/relay/frontend/onnx.py

Lines changed: 33 additions & 11 deletions
@@ -106,7 +106,7 @@ def _impl_v1(cls, inputs, attr, params):
                 'pads': ('padding', (0, 0), revert_caffe2_pad)
             },
             # very weird attributes here in onnx, force check
-            ignores=['dilations'],
+            ignores=['dilations', 'auto_pad'],
             # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
             extras={'ceil_mode': False},
             custom_check=dimension_constraint())(inputs, attr, params)
@@ -160,6 +160,7 @@ def _impl_v1(cls, inputs, attr, params):
                           'dilations': ('dilation', (0, 0)),
                           'pads': ('padding', (0, 0), revert_caffe2_pad),
                           'group': ('groups', 1)},
+                      ignores=['auto_pad'],
                       custom_check=dimension_constraint())(inputs[:2], attr, params)
         use_bias = len(inputs) == 3
         if use_bias:
@@ -332,7 +333,21 @@ def _impl_v1(cls, inputs, attr, params):
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
-            out = _op.reshape_like(inputs[0], inputs[1])
+            # Try to infer shape by precompute prune if possible.
+            # TODO: good to check inputs to be in params.
+            # to be enhanced when relay support list_input_names API of NNVM
+            logging.warning("Infering Reshape argument by precompute")
+            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+            with tvm.relay.build_config(opt_level=0):
+                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            ctx = tvm.context("llvm", 0)
+            from tvm.contrib import graph_runtime
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**params)
+            m.run()
+            params_new = m.get_output(0)
+            inputs.pop(1)
+            out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))

         return out

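When Reshape's shape input is not a baked-in initializer, the new branch above materializes it at conversion time by building and running the constant subgraph at opt_level 0, then emits a static reshape. A rough standalone sketch of that fold, where evaluate_constant is a hypothetical helper standing in for the relay build-and-run step:

import numpy as np

def evaluate_constant(expr):
    # Hypothetical: assumed to evaluate a compile-time-constant expression to a
    # numpy array, analogous to the graph_runtime precompute in the hunk above.
    return np.asarray(expr)

def fold_reshape(data, shape_expr):
    # Evaluate the shape subgraph once, then reshape with a concrete tuple,
    # mirroring the final _op.reshape(...) call above.
    target_shape = tuple(evaluate_constant(shape_expr).astype('int32').flatten())
    return data.reshape(target_shape)

x = np.arange(24, dtype='float32').reshape(4, 3, 2)
print(fold_reshape(x, [6, 4]).shape)   # (6, 4)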

@@ -477,10 +492,7 @@ class Shape(OnnxOpConverter):

     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # Result of this operator is prominently used by reshape operator.
-        # Just pass the input as it is so that reshape_like can be used there.
-        logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)")
-        return inputs[0]
+        return _op.shape_of(inputs[0])

 class Cast(OnnxOpConverter):
     """ Operator converter for Cast.
@@ -494,7 +506,7 @@ def _impl_v1(cls, inputs, attr, params):
     def _impl_v5(cls, inputs, attr, params):
         try:
             from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-            attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
+            attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
         except ImportError as e:
             raise ImportError(
                 "Unable to import onnx.mapping which is required {}".format(e))
@@ -674,6 +686,11 @@ class ReduceMean(Reduce):
     """
     name = 'mean'

+class ReduceProd(Reduce):
+    """ Operator converter for ArgMax.
+    """
+    name = 'prod'
+
 class ArgMax(OnnxOpConverter):
     """ Operator converter for ArgMax.
     """
@@ -826,6 +843,7 @@ def _get_convert_map(opset):
         'ReduceMin': ReduceMin.get_converter(opset),
         'ReduceSum': ReduceSum.get_converter(opset),
         'ReduceMean': ReduceMean.get_converter(opset),
+        'ReduceProd': ReduceProd.get_converter(opset),
         # 'ReduceProd'
         # 'ReduceLogSumExp'
         'ArgMax': ArgMax.get_converter(opset),
@@ -842,8 +860,7 @@ def _get_convert_map(opset):
         'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
         'Unsqueeze': Unsqueeze.get_converter(opset),
         'Pad': Pad.get_converter(opset),
-        # TODO(zhreshold) Shape op is implemented as bypass op in relay
-        # 'Shape': Shape.get_converter(opset),
+        'Shape': Shape.get_converter(opset),
     }


@@ -883,6 +900,7 @@ def from_onnx(self, graph, opset):
         ----------
         graph : onnx protobuf object
             The loaded onnx graph
+
         opset : opset version

         Returns
@@ -911,12 +929,12 @@ def from_onnx(self, graph, opset):
                                               dtype=self._params[i_name].dtype)
             else:
                 self._num_input += 1
-                shape = self._shape[i_name] if i_name in self._shape else ()
+                tshape = self._shape[i_name] if i_name in self._shape else ()
                 if isinstance(self._dtype, dict):
                     dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                 else:
                     dtype = d_type
-                self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype)
+                self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
         # construct nodes, nodes are stored as directed acyclic graph
         for node in graph.node:
             op_name = node.op_type
@@ -936,6 +954,10 @@ def from_onnx(self, graph, opset):
                     self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
                 inputs.append(self._nodes[i_name])

+            i_name = self._parse_value_proto(node)
+            attr['tvm_custom'] = {}
+            attr['tvm_custom']['name'] = i_name
+
             op = self._convert_operator(op_name, inputs, attr, opset)
             node_output = self._fix_outputs(op_name, node.output)
             if not isinstance(op, _expr.TupleWrapper):

tests/python/frontend/onnx/test_forward.py

Lines changed: 16 additions & 15 deletions
@@ -113,35 +113,36 @@ def test_reshape():

     tvm.testing.assert_allclose(ref_shape, tvm_out.shape)

-def test_reshape_like():
+def test_shape():
     in_shape = (4, 3, 3, 4)
-    ref_shape = (3, 4, 4, 3)
+    ref_shape = (6, 2, 4, 3)

-    ref_array = np.random.uniform(size=ref_shape).astype('float32')
+    ref_array = np.array(ref_shape)
     ref_node = onnx.helper.make_node('Constant',
                                  inputs=[],
                                  outputs=['ref_in'],
                                  value=onnx.helper.make_tensor(name = 'const_tensor',
-                                                               data_type = onnx.TensorProto.FLOAT,
+                                                               data_type = onnx.TensorProto.INT32,
                                                                dims = ref_array.shape,
-                                                               vals = ref_array.flatten().astype(float)))
-    copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"])
-    reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"])
+                                                               vals = ref_array.flatten().astype(int)))
+    reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
+
+    shape_node = helper.make_node("Shape", ['out'], ['final_out'])

-    graph = helper.make_graph([ref_node, copy_node, reshape_node],
-                              "reshape_like_test",
+    graph = helper.make_graph([ref_node, reshape_node, shape_node],
+                              "shape_test",
                               inputs = [helper.make_tensor_value_info("in",
                                             TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out",
+                              outputs = [helper.make_tensor_value_info("final_out",
                                             TensorProto.FLOAT, list(ref_shape))])

-    model = helper.make_model(graph, producer_name='reshape_like_test')
+    model = helper.make_model(graph, producer_name='shape_test')

     for target, ctx in ctx_list():
-        x = np.random.uniform(size=in_shape).astype('float32')
-        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
+        x = np.random.uniform(size=in_shape).astype('int32')
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'int32')

-        tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+        tvm.testing.assert_allclose(ref_shape, tvm_out)

 def _test_power_iteration(x_shape, y_shape):
     if isinstance(y_shape, int):
@@ -995,7 +996,7 @@ def test_LogSoftmax():

 if __name__ == '__main__':
     test_reshape()
-    test_reshape_like()
+    test_shape()
     test_power()
     test_squeeze()
     test_unsqueeze()