
Commit 7d0e616

Remove legacy optimizer (#2180)

- Remove the legacy optimizer and support proto inputs with the IR-based optimizer.
- Add a new `inline=True` option in `optimize()` to control whether function inlining is done when optimizing.
- Implement identity folding for graph outputs.
- Migrate constant-folding tests to run on IR models.

Fix #2185

Co-authored-by: Ganesan Ramalingam <grama@microsoft.com>

1 parent 883a74f commit 7d0e616

14 files changed: +329, -1760 lines
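A minimal usage sketch of the reworked `optimize()` entry point introduced by this commit (the model path is hypothetical; the keyword arguments are the ones added in the diff below):

import onnx
from onnxscript import ir, optimizer

# ModelProto input: the model is round-tripped through the IR and a new
# ModelProto is returned.
model_proto = onnx.load("model.onnx")  # hypothetical path
optimized_proto = optimizer.optimize(model_proto, inline=True)

# ir.Model input: the model is optimized in place and the same object is
# returned.
model_ir = ir.serde.deserialize_model(model_proto)
optimizer.optimize(model_ir, inline=False)  # skip function inlining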

onnxscript/optimizer/__init__.py (+77, -15)
@@ -2,37 +2,90 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
+from typing import TypeVar
+
 __all__ = [
-    "fold_constants",
-    "fold_constants_ir",
-    "remove_unused_nodes",
-    "optimize",
-    "optimize_ir",
     "basic_constant_propagation",
+    "fold_constants_ir",
+    "fold_constants",
     "inline",
+    "optimize_ir",
+    "optimize",
+    "remove_unused_nodes",
 ]
 
 import onnx
 
 import onnxscript.ir.passes.common.inliner
 import onnxscript.ir.passes.common.unused_removal
 import onnxscript.optimizer._constant_folding as constant_folding
-import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
-import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
 from onnxscript import ir
+from onnxscript.optimizer._constant_folding import (
+    basic_constant_propagation,
+)
+from onnxscript.optimizer._constant_folding import (
+    fold_constants as fold_constants_ir,
+)
 from onnxscript.optimizer._optimizer import optimize_ir
 
-basic_constant_propagation = constant_folding.basic_constant_propagation
-fold_constants_ir = constant_folding.fold_constants
+_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
+
 
+def optimize(
+    model: _ModelProtoOrIr,
+    num_iterations: int = 2,
+    *,
+    onnx_shape_inference: bool = True,
+    stop_if_no_change: bool = True,
+    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    inline: bool = True,
+) -> _ModelProtoOrIr:
+    """Optimizes a model.
 
-def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
+    Args:
+        model: The model to be optimized.
+        num_iterations: Number of times the optimization loop is repeated.
+        onnx_shape_inference: Applies node-level shape-inference as part of optimization.
+        input_size_limit: Will not apply constant folding to ops with any input of size
+            greater than this. Does not apply to special ops like Shape() and Size().
+        output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
+            of the output tensor is greater than this.
+        stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
+        inline: If True, inlines all functions in the model.
+
+    Returns:
+        The optimized model. If the input was a ModelProto, the output will also be a
+        ModelProto. If the input was an ir.Model, the output will also be an ir.Model.
+    """
     if isinstance(model, ir.Model):
-        # In that case, this is done inplace.
-        optimize_ir(model, *args, **kwargs)
+        # In this case, optimization is done in place.
+        # TODO(justinchuby): Maybe make functional
+        optimize_ir(
+            model,
+            num_iterations=num_iterations,
+            onnx_shape_inference=onnx_shape_inference,
+            stop_if_no_change=stop_if_no_change,
+            input_size_limit=input_size_limit,
+            output_size_limit=output_size_limit,
+            inline=inline,
+        )
         return model
     else:
-        return legacy_optimizer.optimize(model, *args, **kwargs)
+        assert isinstance(model, onnx.ModelProto)
+        model_ir = ir.serde.deserialize_model(model)
+        optimize_ir(
+            model_ir,
+            num_iterations=num_iterations,
+            onnx_shape_inference=onnx_shape_inference,
+            stop_if_no_change=stop_if_no_change,
+            input_size_limit=input_size_limit,
+            output_size_limit=output_size_limit,
+            inline=inline,
+        )
+        # Move the model back to the proto
+        new_proto = ir.serde.serialize_model(model_ir)
+        return new_proto
 
 
 def inline(model: ir.Model) -> None:
@@ -43,11 +96,20 @@ def inline(model: ir.Model) -> None:
 
 def fold_constants(
     model: ir.Model | onnx.ModelProto, *args, **kwargs
-) -> constant_folding.FoldConstantsResult | bool:
+) -> constant_folding.FoldConstantsResult:
+    """Fold constants in a model in place."""
     if isinstance(model, ir.Model):
        return constant_folding.fold_constants(model, *args, **kwargs)
     else:
-        return legacy_constant_folding.fold_constants(model, *args, **kwargs)
+        assert isinstance(model, onnx.ModelProto)
+        model_proto = model
+        model = ir.serde.deserialize_model(model_proto)
+        result = constant_folding.fold_constants(model, *args, **kwargs)
+        # Move the model back to the proto
+        new_proto = ir.serde.serialize_model(model)
+        model_proto.Clear()
+        model_proto.CopyFrom(new_proto)
+        return result
 
 
 def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
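With the legacy path gone, `fold_constants` handles a ModelProto by deserializing it to the IR, folding, and writing the result back into the caller's proto via Clear()/CopyFrom(). A hedged sketch of the resulting in-place semantics (the model path is hypothetical):

import onnx
from onnxscript import optimizer

model = onnx.load("model.onnx")  # hypothetical path
result = optimizer.fold_constants(model)  # `model` is mutated in place
print(result.modified)  # PassResult flag: whether any folding occurred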

onnxscript/optimizer/_constant_folding.py (+39, -5)
@@ -919,7 +919,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 e,
             )
 
-    def new_constant(self, node: ir.Node, value):
+    def new_constant(self, node: ir.Node, value) -> ir.Node | None:
         irvalue = node.outputs[0]
         if not isinstance(value, np.ndarray):
             # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
@@ -965,7 +965,7 @@ def new_constant(self, node: ir.Node, value):
         node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
         return node
 
-    def process_node(self, node: ir.Node):
+    def process_node(self, node: ir.Node) -> Replacement | None:
         for i, value in enumerate(node.inputs):
             sym_value = self._state.get_sym_value(value)
             if isinstance(sym_value, ir.Value):
@@ -1046,7 +1046,7 @@ def convert(av):
            )
         return None
 
-    def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function):
+    def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
         logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)
 
         ir.convenience.replace_nodes_and_values(
@@ -1066,13 +1066,13 @@ def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None:
         for graph in attr.as_graphs():
             self.visit_graph(graph)
 
-    def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function):
+    def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
         replacement = self.process_node(node)
         if replacement is None:
             # No change. Process attributes.
             for attr in node.attributes.values():
                 self.visit_attribute(attr)
-            return None
+            return
         else:
             self.replace_node(node, replacement, root)
 
@@ -1087,6 +1087,22 @@ def visit_graph(self, graph: ir.Graph) -> None:
         for node in graph:
             self.visit_node(node, graph)
 
+        # Replace outputs if output nodes can be folded. These are typically
+        # outputs from Identity nodes.
+        for i, output in enumerate(graph.outputs):
+            if output is None:
+                continue
+            sym_value = self._state.get_sym_value(output)
+            if not isinstance(sym_value, ir.Value):
+                # An output must be a Value
+                continue
+            if not _sym_value_can_replace_graph_output(graph, sym_value, output):
+                continue
+            # Rename sym_value to match the output name
+            sym_value.name = output.name
+            graph.outputs[i] = sym_value
+            self.modified = True
+
         self._state.pop_initializer_inputs()
 
     def visit_function(self, function: ir.Function) -> None:
@@ -1103,6 +1119,24 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
         return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
 
 
+def _sym_value_can_replace_graph_output(
+    graph: ir.Graph, sym_value: ir.Value, output: ir.Value
+) -> bool:
+    if (producer := sym_value.producer()) is None:
+        # If the sym_value has no producer, it is some graph's input.
+        # ONNX does not allow a graph input to be a graph output.
+        return False
+    if producer.graph is not graph:
+        # The sym_value must be produced by a node in this graph to be one of its outputs.
+        return False
+    if sym_value.is_graph_output():
+        # If the sym_value is already an output of a graph, we cannot rename it
+        # to this output name. Otherwise the graph output represented by sym_value
+        # will lose its name.
+        return False
+    return True
+
+
 @dataclasses.dataclass
 class FoldConstantsResult(ir.passes.PassResult):
     symbolic_value_map: dict[ir.Value, SymbolicValue]
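The new `_sym_value_can_replace_graph_output` check and the output-rewriting loop in `visit_graph` together implement the identity folding for graph outputs announced in the commit message. A small self-contained sketch of the intended effect, built with onnx.helper (names and opset are illustrative, and the expected result is an inference from the diff above):

import onnx
from onnx import TensorProto, helper
from onnxscript import optimizer

# A graph whose output is produced by an Identity of an internal value.
add = helper.make_node("Add", ["a", "b"], ["sum"])
ident = helper.make_node("Identity", ["sum"], ["out"])
graph = helper.make_graph(
    [add, ident],
    "identity_output_example",
    [
        helper.make_tensor_value_info("a", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [2]),
    ],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

optimized = optimizer.optimize(model)
# "sum" should be renamed to "out" and become the graph output directly,
# leaving the Identity node unused so it can be removed.
print([n.op_type for n in optimized.graph.node])  # expected: ["Add"]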
