 # Licensed under the MIT License.
 from __future__ import annotations

+from typing import TypeVar
+
 __all__ = [
-    "fold_constants",
-    "fold_constants_ir",
-    "remove_unused_nodes",
-    "optimize",
-    "optimize_ir",
     "basic_constant_propagation",
+    "fold_constants_ir",
+    "fold_constants",
     "inline",
+    "optimize_ir",
+    "optimize",
+    "remove_unused_nodes",
 ]

 import onnx

 import onnxscript.ir.passes.common.inliner
 import onnxscript.ir.passes.common.unused_removal
 import onnxscript.optimizer._constant_folding as constant_folding
-import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
-import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
 from onnxscript import ir
+from onnxscript.optimizer._constant_folding import (
+    basic_constant_propagation,
+)
+from onnxscript.optimizer._constant_folding import (
+    fold_constants as fold_constants_ir,
+)
 from onnxscript.optimizer._optimizer import optimize_ir

-basic_constant_propagation = constant_folding.basic_constant_propagation
-fold_constants_ir = constant_folding.fold_constants
+_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
+

+def optimize(
+    model: _ModelProtoOrIr,
+    num_iterations: int = 2,
+    *,
+    onnx_shape_inference: bool = True,
+    stop_if_no_change: bool = True,
+    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    inline: bool = True,
+) -> _ModelProtoOrIr:
+    """Optimizes a model.

-def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
+    Args:
+        model: The model to be optimized.
+        num_iterations: Number of times the optimization loop is repeated.
+        onnx_shape_inference: Applies node-level shape inference as part of optimization.
+        input_size_limit: Will not apply constant folding to ops with any input of size
+            greater than this. Does not apply to special ops like Shape() and Size().
+        output_size_limit: Will not rewrite any foldable op into a Constant op if the size
+            of the output tensor is greater than this.
+        stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
+        inline: If True, inlines all functions in the model.
+
+    Returns:
+        The optimized model. If the input was a ModelProto, the output will also be a
+        ModelProto. If the input was an ir.Model, the output will also be an ir.Model.
+    """
     if isinstance(model, ir.Model):
-        # In that case, this is done inplace.
-        optimize_ir(model, *args, **kwargs)
+        # In this case, optimization is done in place.
+        # TODO(justinchuby): Maybe make functional
+        optimize_ir(
+            model,
+            num_iterations=num_iterations,
+            onnx_shape_inference=onnx_shape_inference,
+            stop_if_no_change=stop_if_no_change,
+            input_size_limit=input_size_limit,
+            output_size_limit=output_size_limit,
+            inline=inline,
+        )
         return model
     else:
-        return legacy_optimizer.optimize(model, *args, **kwargs)
+        assert isinstance(model, onnx.ModelProto)
+        model_ir = ir.serde.deserialize_model(model)
+        optimize_ir(
+            model_ir,
+            num_iterations=num_iterations,
+            onnx_shape_inference=onnx_shape_inference,
+            stop_if_no_change=stop_if_no_change,
+            input_size_limit=input_size_limit,
+            output_size_limit=output_size_limit,
+            inline=inline,
+        )
+        # Move the model back to the proto
+        new_proto = ir.serde.serialize_model(model_ir)
+        return new_proto


 def inline(model: ir.Model) -> None:
@@ -43,11 +96,20 @@ def inline(model: ir.Model) -> None:

 def fold_constants(
     model: ir.Model | onnx.ModelProto, *args, **kwargs
-) -> constant_folding.FoldConstantsResult | bool:
+) -> constant_folding.FoldConstantsResult:
+    """Fold constants in a model in place."""
     if isinstance(model, ir.Model):
         return constant_folding.fold_constants(model, *args, **kwargs)
     else:
-        return legacy_constant_folding.fold_constants(model, *args, **kwargs)
+        assert isinstance(model, onnx.ModelProto)
+        model_proto = model
+        model = ir.serde.deserialize_model(model_proto)
+        result = constant_folding.fold_constants(model, *args, **kwargs)
+        # Move the model back to the proto
+        new_proto = ir.serde.serialize_model(model)
+        model_proto.Clear()
+        model_proto.CopyFrom(new_proto)
+        return result


 def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
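
With this change, `optimize` and `fold_constants` accept either an `onnx.ModelProto` or an `ir.Model` directly, round-tripping ModelProto inputs through the IR instead of calling the removed `_legacy` implementations. A minimal usage sketch of the updated `optimize` API (not part of this commit; the model path and parameter values are illustrative):

    import onnx
    import onnxscript.optimizer

    # Hypothetical input file; any ONNX model works.
    model_proto = onnx.load("model.onnx")

    # A ModelProto input returns a new, optimized ModelProto.
    optimized = onnxscript.optimizer.optimize(
        model_proto,
        num_iterations=2,
        onnx_shape_inference=False,
    )
    onnx.save(optimized, "model_optimized.onnx")

    # An ir.Model input is optimized in place and returned as an ir.Model.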