-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cleanup optimizer by moving older proto-based optimizations into a _legacy folder, renaming files to distinguish internal implementation files, and other minor restructuring.
- Loading branch information
1 parent
8fef233
commit 4578142
Showing
20 changed files
with
225 additions
and
464 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import Any | ||
|
||
import onnx | ||
import onnx.shape_inference | ||
|
||
from onnxscript import rewriter | ||
from onnxscript.optimizer._legacy._simple_function_folding import ( | ||
inline_functions_with_unused_outputs, | ||
inline_simple_functions, | ||
) | ||
from onnxscript.optimizer._legacy.constant_folding import fold_constants | ||
from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES | ||
from onnxscript.optimizer._remove_unused import remove_unused_nodes | ||
from onnxscript.optimizer._remove_unused_function import remove_unused_functions | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def optimize( | ||
model: onnx.ModelProto, | ||
num_iterations: int = 2, | ||
*, | ||
onnx_shape_inference: bool = True, | ||
stop_if_no_change: bool = True, | ||
external_data_folder: str = "", | ||
**kwargs: Any, | ||
) -> onnx.ModelProto: | ||
"""Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc. | ||
Args: | ||
model (onnx.ModelProto): The model to optimize. | ||
num_iterations (int, optional): Number of iterations to perform. | ||
onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model. | ||
Set this to False to turn off onnx shape inference, and rely on model carried shapes and types. | ||
This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries | ||
the symbolic shapes recorded from dynamo tracing. | ||
stop_if_no_change (bool, optional): Whether to stop if no change is detected. | ||
external_data_folder (str, optional): The folder to store external data. | ||
**kwargs: Additional keyword arguments. For BC purposes. | ||
""" | ||
if kwargs.pop("function_aware_folding", None) is not None: | ||
logger.warning( | ||
"'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. " | ||
"To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. " | ||
"This would turn off incremental onnx shape inference and rely on model carried shapes and types. " | ||
"See 'onnx_shape_inference' for more details." | ||
) | ||
for _ in range(num_iterations): | ||
if onnx_shape_inference: | ||
if model.ByteSize() < 1024 * 1024 * 1024 * 2: | ||
# NOTE: strict mode is disabled because it crashes on the models | ||
# that have different shapes inferred from the model carried shapes. | ||
# The case can be found in: | ||
# https://github.com/microsoft/onnxscript/issues/1443 | ||
model = onnx.shape_inference.infer_shapes( | ||
model, check_type=True, strict_mode=False, data_prop=True | ||
) | ||
else: | ||
logger.warning( | ||
"The model size is too large for full model shape inference. " | ||
"Skipping this step." | ||
) | ||
|
||
inline_simple_functions(model) | ||
modified = fold_constants( | ||
model, external_data_folder, onnx_shape_inference=onnx_shape_inference | ||
) | ||
|
||
remove_unused_nodes(model) | ||
inline_simple_functions(model) | ||
model = remove_unused_functions(model) | ||
inline_functions_with_unused_outputs(model) | ||
# NOTE: This is general rewrite rules | ||
model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) | ||
if stop_if_no_change and not modified: | ||
logger.debug("Stopping after %d iterations.", _) | ||
break | ||
|
||
for node in model.graph.node: | ||
logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name) | ||
|
||
for function in model.functions: | ||
for node in function.node: | ||
logger.debug( | ||
"Function %s::%s node %s::%s name %s.", | ||
function.domain, | ||
function.name, | ||
node.domain, | ||
node.op_type, | ||
node.name, | ||
) | ||
|
||
return model |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.