Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup optimizer #1904

Merged
merged 7 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ exclude_patterns = [
'onnxscript/onnx_types.py',
'onnxscript/**/*_test.py', # Skip linting test files for speed
'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy
'onnxscript/optimizer/evaluator.py', # FIXME
'onnxscript/optimizer/constant_folding.py', # FIXME
'onnxscript/optimizer/legacy/evaluator.py', # FIXME
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
'onnxscript/optimizer/legacy/constant_folding.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
'onnxscript/optimizer/fold_constants_v0.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
'onnxscript/tools/function_unittest_producer.py', # FIXME
'onnxscript/_legacy_ir/visitor.py', # FIXME
Expand Down
160 changes: 11 additions & 149 deletions onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,160 +2,22 @@
# Licensed under the MIT License.
from __future__ import annotations

import logging
from typing import Any

import onnx
import onnx.shape_inference

from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
from onnxscript.optimizer.constant_folding import fold_constants
from onnxscript.optimizer.remove_unused import remove_unused_nodes
from onnxscript.optimizer.remove_unused_function import remove_unused_functions
from onnxscript.optimizer.simple_function_folding import (
inline_functions_with_unused_outputs,
inline_simple_functions,
)
from onnxscript.rewriter import (
broadcast_to_matmul,
cast_constant_of_shape,
gemm_to_matmul_add,
no_op,
)

logger = logging.getLogger(__name__)

_DEFAULT_REWRITE_RULES = [
*no_op.rules.rules, # TODO: merge this rule into constant folding?
*broadcast_to_matmul.rules.rules,
gemm_to_matmul_add.rule,
*cast_constant_of_shape.rules.rules,
]


def optimize(
    model: onnx.ModelProto,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    external_data_folder: str = "",
    **kwargs: Any,
) -> onnx.ModelProto:
    """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc.

    Args:
        model (onnx.ModelProto): The model to optimize.
        num_iterations (int, optional): Number of iterations to perform.
        onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model.
            Set this to False to turn off onnx shape inference, and rely on model carried shapes and types.
            This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries
            the symbolic shapes recorded from dynamo tracing.
        stop_if_no_change (bool, optional): Whether to stop if no change is detected.
        external_data_folder (str, optional): The folder to store external data.
        **kwargs: Additional keyword arguments. For BC purposes.

    Returns:
        The model produced by the last pass. Several passes rebind ``model`` to
        their return value, so callers should use the returned object rather
        than the argument they passed in.
    """
    if kwargs.pop("function_aware_folding", None) is not None:
        logger.warning(
            "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. "
            "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. "
            "This would turn off incremental onnx shape inference and rely on model carried shapes and types. "
            "See 'onnx_shape_inference' for more details."
        )
    for iteration in range(num_iterations):
        if onnx_shape_inference:
            if model.ByteSize() < 1024 * 1024 * 1024 * 2:
                # NOTE: strict mode is disabled because it crashes on the models
                # that have different shapes inferred from the model carried shapes.
                # The case can be found in:
                # https://github.com/microsoft/onnxscript/issues/1443
                model = onnx.shape_inference.infer_shapes(
                    model, check_type=True, strict_mode=False, data_prop=True
                )
            else:
                logger.warning(
                    "The model size is too large for full model shape inference. "
                    "Skipping this step."
                )

        inline_simple_functions(model)
        modified = fold_constants(
            model, external_data_folder, onnx_shape_inference=onnx_shape_inference
        )

        remove_unused_nodes(model)
        inline_simple_functions(model)
        model = remove_unused_functions(model)
        inline_functions_with_unused_outputs(model)
        # NOTE: This is general rewrite rules
        model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
        if stop_if_no_change and not modified:
            # `iteration` is zero-based; report the number of completed passes.
            logger.debug("Stopping after %d iterations.", iteration + 1)
            break

    for node in model.graph.node:
        logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name)

    for function in model.functions:
        for node in function.node:
            logger.debug(
                "Function %s::%s node %s::%s name %s.",
                function.domain,
                function.name,
                node.domain,
                node.op_type,
                node.name,
            )

    return model


_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = (
_constant_folding._DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT
)

_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = (
_constant_folding._DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT
)


def optimize_ir(
model: ir.Model,
num_iterations: int = 2,
*,
onnx_shape_inference: bool = True,
stop_if_no_change: bool = True,
input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
) -> None:
"""Optimizes a model.
import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
from onnxscript import ir
from onnxscript.optimizer._constant_folding import basic_constant_propagation
from onnxscript.optimizer._legacy.constant_folding import fold_constants
from onnxscript.optimizer._optimizer import optimize_ir
from onnxscript.optimizer._remove_unused import remove_unused_nodes

Args:
model: The model to be optimized.
num_iterations: Number of times the optimization loop is repeated.
onnx_shape_inference: Applies node-level shape-inference as part of optimization
input_size_limit: Will not apply constant folding to ops with any input of size
greater than this. Does not apply to special ops like Shape() and Size().
output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
of the output tensor is greater than this.
stop_if_no_change: Not supported currently (has no effect). Meant to stop the
outer optimization loop if no change is detected in one iteration.
"""
del stop_if_no_change # Looks like rewriter doesn't support this yet.
_inliner.inline(model)
for _ in range(num_iterations):
_constant_folding.fold_constants(
model,
onnx_shape_inference=onnx_shape_inference,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
)
rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
remove_unused_nodes(model)

def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
    """Optimize ``model`` with the optimizer that matches its representation.

    An ``ir.Model`` is handed to ``optimize_ir``; anything else is handed to
    the legacy proto-based ``legacy_optimizer.optimize``. Positional and
    keyword arguments are forwarded unchanged, and the selected optimizer's
    return value is returned as-is.
    """
    if not isinstance(model, ir.Model):
        return legacy_optimizer.optimize(model, *args, **kwargs)
    return optimize_ir(model, *args, **kwargs)

basic_constant_propagation = _constant_folding.basic_constant_propagation

__all__ = [
"fold_constants",
Expand Down
28 changes: 18 additions & 10 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,32 @@

import onnxscript.ir as ir
import onnxscript.ir._convenience as _convenience
import onnxscript.optimizer.constant_folding as constant_folding
import onnxscript.rewriter.pattern as orp
import onnxscript.utils.utils as utils

DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024

DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024


def is_control_flow_op(node: ir.Node) -> bool:
    """Return True when at least one attribute of ``node`` has type GRAPH or GRAPHS."""
    subgraph_attr_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
    for attribute in node.attributes.values():
        if attribute.type in subgraph_attr_types:
            return True
    return False


# Op types looked up by is_non_deterministic_op (combined there with an
# ONNX-domain check on the node).
non_deterministic_ops = frozenset(
{
"RandomUniform",
"RandomNormal",
"RandomUniformLike",
"RandomNormalLike",
"Multinomial",
}
)


def is_non_deterministic_op(node: ir.Node) -> bool:
    """Return True when ``node``'s op type is in ``non_deterministic_ops`` and
    ``utils.is_onnx_domain`` accepts its domain.
    """
    # Use the set defined in this module; the earlier lookup through the
    # ``constant_folding`` module alias was redundant with this one.
    return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain)


def is_onnx_op(node: ir.Node, op_type: str) -> bool:
Expand All @@ -43,10 +55,6 @@ def is_constant_op(node: ir.Node) -> bool:
)


_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024

_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT

logger = logging.getLogger(__name__)

# "Standard" evaluators are used to perform constant-folding.
Expand Down Expand Up @@ -787,8 +795,8 @@ def fold_constants(
external_data_folder: str = "",
*,
onnx_shape_inference: bool = False,
input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
) -> bool:
"""
Applies constant folding optimization to the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import onnxscript.optimizer as optimizer
from onnxscript.ir import serde
from onnxscript.optimizer import _constant_folding, constant_folding
from onnxscript.optimizer import _constant_folding
from onnxscript.optimizer._legacy import constant_folding


@parameterized.parameterized_class(("using_ir",), [(False,), (True,)])
Expand Down
98 changes: 98 additions & 0 deletions onnxscript/optimizer/_legacy/_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
from __future__ import annotations
Fixed Show fixed Hide fixed

import logging
from typing import Any

import onnx
import onnx.shape_inference

from onnxscript import rewriter
from onnxscript.optimizer._legacy._simple_function_folding import (
inline_functions_with_unused_outputs,
inline_simple_functions,
)
from onnxscript.optimizer._legacy.constant_folding import fold_constants
from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES
from onnxscript.optimizer._remove_unused import remove_unused_nodes
from onnxscript.optimizer._remove_unused_function import remove_unused_functions

logger = logging.getLogger(__name__)


def optimize(
    model: onnx.ModelProto,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    external_data_folder: str = "",
    **kwargs: Any,
) -> onnx.ModelProto:
    """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc.

    Args:
        model (onnx.ModelProto): The model to optimize.
        num_iterations (int, optional): Number of iterations to perform.
        onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model.
            Set this to False to turn off onnx shape inference, and rely on model carried shapes and types.
            This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries
            the symbolic shapes recorded from dynamo tracing.
        stop_if_no_change (bool, optional): Whether to stop if no change is detected.
        external_data_folder (str, optional): The folder to store external data.
        **kwargs: Additional keyword arguments. For BC purposes.

    Returns:
        The model produced by the last pass. Several passes rebind ``model`` to
        their return value, so callers should use the returned object rather
        than the argument they passed in.
    """
    if kwargs.pop("function_aware_folding", None) is not None:
        logger.warning(
            "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. "
            "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. "
            "This would turn off incremental onnx shape inference and rely on model carried shapes and types. "
            "See 'onnx_shape_inference' for more details."
        )
    for iteration in range(num_iterations):
        if onnx_shape_inference:
            if model.ByteSize() < 1024 * 1024 * 1024 * 2:
                # NOTE: strict mode is disabled because it crashes on the models
                # that have different shapes inferred from the model carried shapes.
                # The case can be found in:
                # https://github.com/microsoft/onnxscript/issues/1443
                model = onnx.shape_inference.infer_shapes(
                    model, check_type=True, strict_mode=False, data_prop=True
                )
            else:
                logger.warning(
                    "The model size is too large for full model shape inference. "
                    "Skipping this step."
                )

        inline_simple_functions(model)
        modified = fold_constants(
            model, external_data_folder, onnx_shape_inference=onnx_shape_inference
        )

        remove_unused_nodes(model)
        inline_simple_functions(model)
        model = remove_unused_functions(model)
        inline_functions_with_unused_outputs(model)
        # NOTE: This is general rewrite rules
        model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
        if stop_if_no_change and not modified:
            # `iteration` is zero-based; report the number of completed passes.
            logger.debug("Stopping after %d iterations.", iteration + 1)
            break

    for node in model.graph.node:
        logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name)

    for function in model.functions:
        for node in function.node:
            logger.debug(
                "Function %s::%s node %s::%s name %s.",
                function.domain,
                function.name,
                node.domain,
                node.op_type,
                node.name,
            )

    return model
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import onnxscript._legacy_ir as ir
from onnxscript._legacy_ir import visitor
from onnxscript.optimizer import remove_unused_proto
from onnxscript.optimizer._legacy import _remove_unused_proto

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -168,7 +168,7 @@ def _find_nodes_with_any_unused_output(
# All unused output means the node is not used at all.
# Hence do not update used_values with the node's inputs.
continue
used_values |= remove_unused_proto.compute_used_in_node(node)
used_values |= _remove_unused_proto.compute_used_in_node(node)
return target_nodes

def visit_model(self, model: onnx.ModelProto) -> None:
Expand Down
Loading
Loading