From cf7046697064db44ed573f6fe21ec657ccb28054 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Mon, 31 Jan 2022 23:58:53 +0000 Subject: [PATCH] [ONNX] Improve scope inference in function extraction Cover more cases of scope inferencing where consecutive nodes don't have valid scope information. Usually these nodes are created in some pass where authors forgot to assign meaningful scope to them. * One rule of `InferScope` is to check if the current node's outputs' users share the same scope. Recursively run `InferScope` on the user nodes if they are missing scope as well. Since the graph is SSA, the depth is finite. * Fix one pass that missed scope information for a new node. Pull Request resolved: https://github.com/pytorch/pytorch/pull/71897 --- test/onnx/test_utility_funs.py | 22 +++++++++++++++++++ .../jit/passes/onnx/function_extraction.cpp | 7 ++++++ torch/csrc/jit/passes/onnx/helper.cpp | 1 + torch/onnx/__init__.py | 2 -- torch/onnx/symbolic_helper.py | 1 + torch/onnx/utils.py | 3 ++- 6 files changed, 33 insertions(+), 3 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index d7bef66811801..dca45fc5c3114 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -752,6 +752,24 @@ def forward(self, x, y, z): self.assertIn("NWithOverloads.1", func_names) self.assertIn("NWithOverloads.2", func_names) + @skipIfUnsupportedMinOpsetVersion(15) + def test_local_function_infer_scopes(self): + class M(torch.nn.Module): + def forward(self, x): + # Concatenation of scalars inserts unscoped tensors in IR graph. + new_tensor_shape = x.size()[:-1] + (1, 1, -1) + tensor = x.view(*new_tensor_shape) + return tensor + + x = torch.randn(4, 5) + f = io.BytesIO() + torch.onnx.export(M(), (x,), f, export_modules_as_functions=True, + opset_version=self.opset_version, do_constant_folding=False) + + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + funcs = onnx_model.functions + self.assertIn("M", [f.name for f in funcs]) + def test_aten_fallthrough(self): # Test aten export of op with no symbolic class Module(torch.nn.Module): @@ -1222,5 +1240,9 @@ class TestUtilityFuns_opset14(TestUtilityFuns_opset9): opset_version = 14 +class TestUtilityFuns_opset15(TestUtilityFuns_opset9): + opset_version = 15 + + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index 5a0f8592f3bd8..1840d96fd13a7 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -353,6 +353,13 @@ c10::optional FunctionExtractor::InferScope(Node* n) { } for (auto output : n->outputs()) { for (auto use : output->uses()) { + if (!IsValidScope(use.user->scope())) { + auto inferred_output_scope = InferScope(use.user); + if (inferred_output_scope.has_value() && + IsValidScope(inferred_output_scope.value())) { + use.user->setScope(inferred_output_scope.value()); + } + } output_scopes.emplace_back(use.user->scope()); } } diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 83206935048a0..f76b606c18122 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -189,6 +189,7 @@ Node* transformToONNXConcatNode( Node* unsqueezed_node = createONNXUnsqueeze(g, new_node, new_input, 0, opset_version); + unsqueezed_node->copyMetadata(lc_node); unsqueezed.emplace_back(unsqueezed_node->output()); } diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 31bffdabca1e4..2a049f37e35dc 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -9,8 +9,6 @@ producer_name = "pytorch" producer_version = _C._onnx.PRODUCER_VERSION -constant_folding_opset_versions = [9, 10, 11, 12, 13, 14] - class ExportTypes: r""""Specifies how the ONNX model is stored.""" diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 3d62dc2df3cbd..d6746967d76b2 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -839,6 +839,7 @@ def _handle_reduce_dim_none(g, self, op_name): _onnx_main_opset = 15 _onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13, 14] _export_onnx_opset_version = _default_onnx_opset_version +_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1)) def _set_opset_version(opset_version): diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index ad960a4b3c168..50023313ee74b 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -532,7 +532,8 @@ def _model_to_graph(model, args, verbose=False, if training is None or training == TrainingMode.EVAL: params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) - if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions: + from torch.onnx.symbolic_helper import _constant_folding_opset_versions + if do_constant_folding and _export_onnx_opset_version in _constant_folding_opset_versions: params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict, _export_onnx_opset_version) torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)