@@ -106,7 +106,7 @@ def _impl_v1(cls, inputs, attr, params):
                 'pads': ('padding', (0, 0), revert_caffe2_pad)
             },
             # very weird attributes here in onnx, force check
-            ignores=['dilations'],
+            ignores=['dilations', 'auto_pad'],
             # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
             extras={'ceil_mode': False},
             custom_check=dimension_constraint())(inputs, attr, params)
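
A note on the `ignores` additions here and in the Conv hunk below: `AttrCvt` (from `tvm.relay.frontend.common`) drops any attribute listed under `ignores` instead of forwarding it, so ONNX graphs that carry the `auto_pad` attribute no longer fail attribute conversion. A rough sketch of that behaviour, calling `AttrCvt` directly the way this file does (input shape and attribute values are made up for illustration):

```python
from tvm import relay
from tvm.relay.frontend.common import AttrCvt

data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
# ONNX exporters often emit auto_pad='NOTSET' alongside explicit pads/strides;
# listing it under `ignores` drops it instead of raising an unsupported-attribute error.
attr = {'kernel_shape': (2, 2), 'strides': (2, 2), 'auto_pad': 'NOTSET'}
out = AttrCvt('max_pool2d',
              transforms={'kernel_shape': 'pool_size'},
              ignores=['auto_pad'])([data], attr)
# `out` is a relay nn.max_pool2d call with pool_size=(2, 2) and strides=(2, 2).
```
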
@@ -160,6 +160,7 @@ def _impl_v1(cls, inputs, attr, params):
                           'dilations': ('dilation', (0, 0)),
                           'pads': ('padding', (0, 0), revert_caffe2_pad),
                           'group': ('groups', 1)},
+                      ignores=['auto_pad'],
                       custom_check=dimension_constraint())(inputs[:2], attr, params)
         use_bias = len(inputs) == 3
         if use_bias:
@@ -332,7 +333,21 @@ def _impl_v1(cls, inputs, attr, params):
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
-            out = _op.reshape_like(inputs[0], inputs[1])
+            # Try to infer the shape argument by precompute prune, if possible.
+            # TODO: it would be good to check that the inputs are in params;
+            #       to be enhanced once relay supports a list_input_names API like NNVM.
+            logging.warning("Inferring Reshape argument by precompute")
+            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+            with tvm.relay.build_config(opt_level=0):
+                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            ctx = tvm.context("llvm", 0)
+            from tvm.contrib import graph_runtime
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**params)
+            m.run()
+            params_new = m.get_output(0)
+            inputs.pop(1)
+            out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))

         return out

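
The new `else` branch constant-folds the expression feeding `Reshape`'s shape input: the subgraph is built at `opt_level=0`, executed once through `graph_runtime`, and the resulting tensor is read back so a static `reshape` can be emitted. A standalone sketch of the same precompute trick, using the same TVM 0.x `relay.build`/`graph_runtime` API as the diff (the constant stands in for whatever subgraph produces the shape):

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Stand-in for the subgraph that produces the Reshape target shape at runtime.
shape_expr = relay.const(np.array([2, 3, 4], dtype='int64'))
func = relay.Function(relay.ir_pass.free_vars(shape_expr), shape_expr)

# Build without optimizations and run once to materialize the value.
with relay.build_config(opt_level=0):
    graph, lib, params = relay.build(func, target="llvm")
m = graph_runtime.create(graph, lib, tvm.context("llvm", 0))
m.set_input(**params)
m.run()
new_shape = tuple(m.get_output(0).asnumpy().astype('int32').flatten())
print(new_shape)  # (2, 3, 4)
```
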
@@ -477,10 +492,7 @@ class Shape(OnnxOpConverter):

     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # Result of this operator is prominently used by reshape operator.
-        # Just pass the input as it is so that reshape_like can be used there.
-        logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)")
-        return inputs[0]
+        return _op.shape_of(inputs[0])

 class Cast(OnnxOpConverter):
     """ Operator converter for Cast.
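
With `Shape` lowered to `_op.shape_of`, the converter now returns an actual 1-D tensor of dimensions instead of passing the data through and relying on `reshape_like` downstream. Roughly (hypothetical input shape):

```python
from tvm import relay
from tvm.relay import op as _op

x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
s = _op.shape_of(x)
# `s` is a 1-D int32 tensor holding [1, 3, 224, 224]; consumers such as the
# Reshape converter above can now evaluate it instead of receiving `x` itself.
```
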
@@ -494,7 +506,7 @@ def _impl_v1(cls, inputs, attr, params):
     def _impl_v5(cls, inputs, attr, params):
         try:
             from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-            attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
+            attr['to'] = str(TENSOR_TYPE_TO_NP_TYPE[attr['to']])
         except ImportError as e:
             raise ImportError(
                 "Unable to import onnx.mapping which is required {}".format(e))
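
The added `str()` matters because `TENSOR_TYPE_TO_NP_TYPE` maps ONNX element types to `numpy.dtype` objects, while relay's `cast` expects the dtype as a plain string. For example, with a standard `onnx` install:

```python
from onnx import TensorProto
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

np_dtype = TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT]
print(repr(np_dtype))  # dtype('float32')  -- a numpy.dtype object
print(str(np_dtype))   # float32           -- the string form relay's cast accepts
```
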
@@ -674,6 +686,11 @@ class ReduceMean(Reduce):
     """
     name = 'mean'

+class ReduceProd(Reduce):
+    """ Operator converter for ReduceProd.
+    """
+    name = 'prod'
+
 class ArgMax(OnnxOpConverter):
     """ Operator converter for ArgMax.
     """
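
`ReduceProd` reuses the shared `Reduce` base class, so the ONNX `axes`/`keepdims` attributes are mapped onto the relay reduction selected by `name = 'prod'`. A quick sketch of the relay op it ends up emitting (shapes are illustrative):

```python
from tvm import relay
from tvm.relay import op as _op

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
# Rough equivalent of ONNX ReduceProd(axes=[1], keepdims=1) after conversion.
y = _op.prod(x, axis=[1], keepdims=True)
```
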
@@ -826,6 +843,7 @@ def _get_convert_map(opset):
         'ReduceMin': ReduceMin.get_converter(opset),
         'ReduceSum': ReduceSum.get_converter(opset),
         'ReduceMean': ReduceMean.get_converter(opset),
+        'ReduceProd': ReduceProd.get_converter(opset),
         # 'ReduceProd'
         # 'ReduceLogSumExp'
         'ArgMax': ArgMax.get_converter(opset),
@@ -842,8 +860,7 @@ def _get_convert_map(opset):
         'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
         'Unsqueeze': Unsqueeze.get_converter(opset),
         'Pad': Pad.get_converter(opset),
-        # TODO(zhreshold) Shape op is implemented as bypass op in relay
-        # 'Shape': Shape.get_converter(opset),
+        'Shape': Shape.get_converter(opset),
     }

@@ -883,6 +900,7 @@ def from_onnx(self, graph, opset):
         ----------
         graph : onnx protobuf object
             The loaded onnx graph
+
         opset : opset version

         Returns
@@ -911,12 +929,12 @@ def from_onnx(self, graph, opset):
                                               dtype=self._params[i_name].dtype)
             else:
                 self._num_input += 1
-                shape = self._shape[i_name] if i_name in self._shape else ()
+                tshape = self._shape[i_name] if i_name in self._shape else ()
                 if isinstance(self._dtype, dict):
                     dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                 else:
                     dtype = d_type
-                self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype)
+                self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
         # construct nodes, nodes are stored as directed acyclic graph
         for node in graph.node:
             op_name = node.op_type
@@ -936,6 +954,10 @@ def from_onnx(self, graph, opset):
                     self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
                 inputs.append(self._nodes[i_name])

+            i_name = self._parse_value_proto(node)
+            attr['tvm_custom'] = {}
+            attr['tvm_custom']['name'] = i_name
+
             op = self._convert_operator(op_name, inputs, attr, opset)
             node_output = self._fix_outputs(op_name, node.output)
             if not isinstance(op, _expr.TupleWrapper):