Skip to content

Commit adfdcda

Browse files
author
Matthew Brookhart
committed
handle onnx graph initializer parameters more intelligently
1 parent 6a027a3 commit adfdcda

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2004,7 +2004,7 @@ class GraphProto():
20042004
def __init__(self, shape, dtype):
20052005
self._nodes = {}
20062006
self._params = {}
2007-
self._inputs = []
2007+
self._inputs = {}
20082008
self._renames = {}
20092009
self._num_input = 0
20102010
self._num_param = 0
@@ -2075,7 +2075,7 @@ def from_onnx(self, graph, opset, freeze_params=False):
20752075
else:
20762076
dtype = d_type
20772077
self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
2078-
self._inputs.append(self._nodes[i_name])
2078+
self._inputs[i_name] = self._nodes[i_name]
20792079
# get list of unsupported ops
20802080
convert_map = _get_convert_map(opset)
20812081
unsupported_ops = set()
@@ -2131,8 +2131,15 @@ def from_onnx(self, graph, opset, freeze_params=False):
21312131
# now return the outputs
21322132
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
21332133
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
2134-
2135-
func = _function.Function(self._inputs, outputs)
2134+
## Maintain the order of inputs and parametersfrom the ONNX graph, but only include
2135+
## those parameters that are needed to execute the relay graph
2136+
free_vars = analysis.free_vars(outputs)
2137+
nodes = {v:k for k,v in self._nodes.items()}
2138+
free_vars = [nodes[var] for var in free_vars]
2139+
for i_name in self._params:
2140+
if i_name in free_vars and i_name not in self._inputs:
2141+
self._inputs[i_name] = self._nodes[i_name]
2142+
func = _function.Function([v for k,v in self._inputs.items()], outputs)
21362143
if freeze_params:
21372144
func, params = self.freeze(func, self._params)
21382145
return IRModule.from_expr(func), params

0 commit comments

Comments
 (0)