Skip to content

Commit

Permalink
[ONNX] Improve scope inference in function extraction
Browse files Browse the repository at this point in the history
Cover more cases of scope inferencing where consecutive nodes don't have valid scope information. Usually these nodes are created in some pass where authors forgot to assign meaningful scope to them.
* One rule of `InferScope` is to check if the current node's outputs' users share the same scope. Recursively run `InferScope` on the user nodes if they are missing scope as well. Since the graph is SSA, the depth is finite.
* Fix one pass that missed scope information for a new node.
Pull Request resolved: pytorch#71897
  • Loading branch information
BowenBao authored and pytorchmergebot committed Jan 31, 2022
1 parent a83cf17 commit cf70466
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 3 deletions.
22 changes: 22 additions & 0 deletions test/onnx/test_utility_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,24 @@ def forward(self, x, y, z):
self.assertIn("NWithOverloads.1", func_names)
self.assertIn("NWithOverloads.2", func_names)

@skipIfUnsupportedMinOpsetVersion(15)
def test_local_function_infer_scopes(self):
class M(torch.nn.Module):
def forward(self, x):
# Concatenation of scalars inserts unscoped tensors in IR graph.
new_tensor_shape = x.size()[:-1] + (1, 1, -1)
tensor = x.view(*new_tensor_shape)
return tensor

x = torch.randn(4, 5)
f = io.BytesIO()
torch.onnx.export(M(), (x,), f, export_modules_as_functions=True,
opset_version=self.opset_version, do_constant_folding=False)

onnx_model = onnx.load(io.BytesIO(f.getvalue()))
funcs = onnx_model.functions
self.assertIn("M", [f.name for f in funcs])

def test_aten_fallthrough(self):
# Test aten export of op with no symbolic
class Module(torch.nn.Module):
Expand Down Expand Up @@ -1222,5 +1240,9 @@ class TestUtilityFuns_opset14(TestUtilityFuns_opset9):
opset_version = 14


class TestUtilityFuns_opset15(TestUtilityFuns_opset9):
opset_version = 15


if __name__ == "__main__":
run_tests()
7 changes: 7 additions & 0 deletions torch/csrc/jit/passes/onnx/function_extraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,13 @@ c10::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
}
for (auto output : n->outputs()) {
for (auto use : output->uses()) {
if (!IsValidScope(use.user->scope())) {
auto inferred_output_scope = InferScope(use.user);
if (inferred_output_scope.has_value() &&
IsValidScope(inferred_output_scope.value())) {
use.user->setScope(inferred_output_scope.value());
}
}
output_scopes.emplace_back(use.user->scope());
}
}
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/passes/onnx/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ Node* transformToONNXConcatNode(

Node* unsqueezed_node =
createONNXUnsqueeze(g, new_node, new_input, 0, opset_version);
unsqueezed_node->copyMetadata(lc_node);
unsqueezed.emplace_back(unsqueezed_node->output());
}

Expand Down
2 changes: 0 additions & 2 deletions torch/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

producer_name = "pytorch"
producer_version = _C._onnx.PRODUCER_VERSION
constant_folding_opset_versions = [9, 10, 11, 12, 13, 14]


class ExportTypes:
r""""Specifies how the ONNX model is stored."""
Expand Down
1 change: 1 addition & 0 deletions torch/onnx/symbolic_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ def _handle_reduce_dim_none(g, self, op_name):
_onnx_main_opset = 15
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13, 14]
_export_onnx_opset_version = _default_onnx_opset_version
_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))


def _set_opset_version(opset_version):
Expand Down
3 changes: 2 additions & 1 deletion torch/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ def _model_to_graph(model, args, verbose=False,
if training is None or training == TrainingMode.EVAL:
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)

if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
from torch.onnx.symbolic_helper import _constant_folding_opset_versions
if do_constant_folding and _export_onnx_opset_version in _constant_folding_opset_versions:
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
_export_onnx_opset_version)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
Expand Down

0 comments on commit cf70466

Please sign in to comment.