Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,

origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")

new_proto = GraphUtil.optimize_model_proto(origin_proto, catch_errors=False)
new_proto, new_graph = GraphUtil.optimize_model_proto(origin_proto, catch_errors=False, return_graph=True)

self.assertTrue(new_proto, msg="model proto after optimizer should not be None")

Expand All @@ -52,6 +52,8 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
self.assertEqual(expected_val.dtype, actual_val.dtype)
self.assertEqual(expected_val.shape, actual_val.shape)

self.assert_shapes_correct(new_graph, allow_missing=False, run_checker=True)

return new_proto

@staticmethod
Expand Down
8 changes: 6 additions & 2 deletions tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,11 +1565,11 @@ def optimize_graph(graph, catch_errors=True):
return optimizer.optimize_graph(graph, catch_errors)

@staticmethod
def optimize_model_proto(onnx_model_proto, catch_errors=True):
def optimize_model_proto(onnx_model_proto, catch_errors=True, return_graph=False):
"""Optimize the model proto, for example: eliminating all useless Transpose pairs.

Returns:
model proto after optimization, if optimizer run successfully
model proto (and possibly graph) after optimization, if optimizer run successfully
or onnx_model_proto, if exceptions happens
"""
try:
Expand All @@ -1582,13 +1582,17 @@ def optimize_model_proto(onnx_model_proto, catch_errors=True):
if onnx_model_proto.metadata_props:
metadata_props = {p.key: p.value for p in onnx_model_proto.metadata_props}
helper.set_model_props(model_proto, metadata_props)
if return_graph:
return model_proto, graph
return model_proto
except Exception as e:
if not catch_errors:
raise e
# sometimes, onnx shape inference will fail for some reason,
# return onnx_model_proto for this case
logger.warning("Failed to optimize model proto", exc_info=1)
if return_graph:
return onnx_model_proto, None
return onnx_model_proto

@staticmethod
Expand Down