@@ -2004,7 +2004,7 @@ class GraphProto():
2004
2004
def __init__ (self , shape , dtype ):
2005
2005
self ._nodes = {}
2006
2006
self ._params = {}
2007
- self ._inputs = []
2007
+ self ._inputs = {}
2008
2008
self ._renames = {}
2009
2009
self ._num_input = 0
2010
2010
self ._num_param = 0
@@ -2075,7 +2075,7 @@ def from_onnx(self, graph, opset, freeze_params=False):
2075
2075
else :
2076
2076
dtype = d_type
2077
2077
self ._nodes [i_name ] = new_var (i_name , shape = tshape , dtype = dtype )
2078
- self ._inputs . append ( self ._nodes [i_name ])
2078
+ self ._inputs [ i_name ] = self ._nodes [i_name ]
2079
2079
# get list of unsupported ops
2080
2080
convert_map = _get_convert_map (opset )
2081
2081
unsupported_ops = set ()
@@ -2131,8 +2131,15 @@ def from_onnx(self, graph, opset, freeze_params=False):
2131
2131
# now return the outputs
2132
2132
outputs = [self ._nodes [self ._parse_value_proto (i )] for i in graph .output ]
2133
2133
outputs = outputs [0 ] if len (outputs ) == 1 else _expr .Tuple (outputs )
2134
-
2135
- func = _function .Function (self ._inputs , outputs )
2134
+ ## Maintain the order of inputs and parametersfrom the ONNX graph, but only include
2135
+ ## those parameters that are needed to execute the relay graph
2136
+ free_vars = analysis .free_vars (outputs )
2137
+ nodes = {v :k for k ,v in self ._nodes .items ()}
2138
+ free_vars = [nodes [var ] for var in free_vars ]
2139
+ for i_name in self ._params :
2140
+ if i_name in free_vars and i_name not in self ._inputs :
2141
+ self ._inputs [i_name ] = self ._nodes [i_name ]
2142
+ func = _function .Function ([v for k ,v in self ._inputs .items ()], outputs )
2136
2143
if freeze_params :
2137
2144
func , params = self .freeze (func , self ._params )
2138
2145
return IRModule .from_expr (func ), params
0 commit comments