[Relay][Frontend][Onnx] Auto extract onnx input shapes when possible. (#7115)

* Auto extract onnx input shapes when possible.

* Remove shape dict definition in tvmc.
jwfromm authored Dec 16, 2020
1 parent 0a3e178 commit 18cf9b9
Showing 2 changed files with 20 additions and 28 deletions.
11 changes: 1 addition & 10 deletions python/tvm/driver/tvmc/frontends.py
@@ -161,16 +161,7 @@ def load(self, path):
         # pylint: disable=E1101
         model = onnx.load(path)
 
-        # pylint: disable=E1101
-        name = model.graph.input[0].name
-
-        # pylint: disable=E1101
-        proto_shape = model.graph.input[0].type.tensor_type.shape.dim
-        shape = [d.dim_value for d in proto_shape]
-
-        shape_dict = {name: shape}
-
-        return relay.frontend.from_onnx(model, shape_dict)
+        return relay.frontend.from_onnx(model)
 
 
 class TensorflowFrontend(Frontend):
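The net effect of the tvmc change above is that an ONNX model can be handed straight to relay.frontend.from_onnx without first building a shape dictionary. A minimal sketch of that flow, assuming an ONNX model saved locally as model.onnx (the file name is illustrative, not part of this commit):

import onnx
from tvm import relay

# Load the serialized ONNX model from disk (hypothetical path).
model = onnx.load("model.onnx")

# Input shapes are now read from the model's graph inputs, so no shape dict
# is required here; this is the call tvmc makes after this commit.
mod, params = relay.frontend.from_onnx(model)

An explicit shape dict still takes precedence over the extracted shapes; see the sketch after the from_onnx hunk below.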
37 changes: 19 additions & 18 deletions python/tvm/relay/frontend/onnx.py
@@ -111,15 +111,20 @@ def get_type(elem_type):
 def get_info(info_proto):
     """Extract the shape from a ValueInfoProto."""
     shape = []
+    shape_name = []
     for dim in info_proto.type.tensor_type.shape.dim:
+        name = dim.dim_param
         value = dim.dim_value
-        if value is None:
-            value = _ty.Any
+        if value is None or value == 0:
+            value = _ty.Any()
+            shape_name.append(name)
+        else:
+            shape_name.append(value)
         shape.append(value)
 
     name = info_proto.name
     dtype = get_type(info_proto.type.tensor_type.elem_type)
-    return name, shape, dtype
+    return name, shape, dtype, shape_name
 
 
 def dimension_picker(prefix, suffix=""):
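For context on the new check: in a ValueInfoProto, a symbolic dimension carries its name in dim_param while dim_value holds the protobuf default of 0 rather than None, which is presumably why the condition gained the == 0 clause. Such a dimension becomes a relay Any(), which stringifies as "?" (hence the "?" in str(i_shape) check further down), and its symbolic name is kept in shape_name for the warning. A standalone sketch of that extraction logic, using a hand-built input named "data" with a symbolic "batch" dimension (both names are illustrative) and a plain "?" string standing in for Any():

import onnx
from onnx import TensorProto, helper

# Build a graph input whose first dimension is symbolic ("batch" -> dim_param).
inp = helper.make_tensor_value_info("data", TensorProto.FLOAT, ["batch", 3, 224, 224])

shape, shape_name = [], []
for dim in inp.type.tensor_type.shape.dim:
    if dim.dim_value is None or dim.dim_value == 0:
        shape.append("?")                  # the real code appends relay Any() here
        shape_name.append(dim.dim_param)   # keep the symbolic name, e.g. "batch"
    else:
        shape.append(dim.dim_value)
        shape_name.append(dim.dim_value)

print(shape)       # ['?', 3, 224, 224]
print(shape_name)  # ['batch', 3, 224, 224]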
@@ -2185,7 +2190,7 @@ def get_var(name, val, scan=False):
         scan_output_vars = []
         scan_output_init = []
         for i in range(num_scan_outputs):
-            name, shape, dtype = get_info(body.output[i + 1 + num_deps])
+            name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps])
             scan_output_vars.append(_expr.var(name, shape=([_ty.Any()] + shape), dtype=dtype))
             scan_output_init.append(_op.reshape(_expr.const([]), [0] + shape))
 
@@ -2829,8 +2834,7 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
         for i in graph.input:
             # from onnx v0.2, GraphProto.input has type ValueInfoProto,
             # and the name is 'i.name'
-            i_name = self._parse_value_proto(i)
-            d_type = self._parse_dtype(i, "float32")
+            i_name, i_shape, d_type, i_shape_name = get_info(i)
             if i_name in self._params:
                 # i is a param instead of input
                 self._num_param += 1
@@ -2841,14 +2845,20 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
             else:
                 self._num_input += 1
                 if i_name in self._shape:
-                    tshape = self._shape[i_name]
+                    i_shape = self._shape[i_name]
                 else:
-                    raise ValueError("Must provide an input shape for `{0}`.".format(i_name))
+                    if "?" in str(i_shape):
+                        warning_msg = (
+                            "Input %s has unknown dimension shapes: %s. "
+                            "Specifying static values may improve performance"
+                            % (i_name, str(i_shape_name))
+                        )
+                        warnings.warn(warning_msg)
                 if isinstance(self._dtype, dict):
                     dtype = self._dtype[i_name] if i_name in self._dtype else d_type
                 else:
                     dtype = d_type
-                self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype)
+                self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
             self._inputs[i_name] = self._nodes[i_name]
         # get list of unsupported ops
         convert_map = _get_convert_map(opset)
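From the caller's side, the hunk above replaces the hard ValueError with a best-effort path: unknown dimensions are imported as Any() and only trigger a warning, while a user-supplied shape for that input bypasses the warning entirely. A sketch of both paths, assuming a model file dynamic_model.onnx whose input "data" has a symbolic batch dimension (file and input names are assumptions):

import warnings
import onnx
from tvm import relay

model = onnx.load("dynamic_model.onnx")

# No shape given: the dynamic batch dimension becomes Any() and the
# "has unknown dimension shapes" warning added above is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mod, params = relay.frontend.from_onnx(model)
for w in caught:
    print(w.message)

# Shape supplied for "data": it overrides the extracted shape, so the import
# is fully static and no warning is raised for that input.
mod, params = relay.frontend.from_onnx(model, shape={"data": (1, 3, 224, 224)})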
@@ -2935,15 +2945,6 @@ def _parse_value_proto(self, value_proto):
             name = value_proto
         return name
 
-    def _parse_dtype(self, value_proto, dtype):
-        """Parse dtype."""
-        try:
-            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-
-            return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
-        except AttributeError:
-            return dtype
-
     def _parse_array(self, tensor_proto):
         np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims))
         return _nd.array(np_array)
