diff --git a/.lintrunner.toml b/.lintrunner.toml index aa88d1f66..9b874e221 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -46,12 +46,11 @@ exclude_patterns = [ 'onnxscript/onnx_types.py', 'onnxscript/**/*_test.py', # Skip linting test files for speed 'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy - 'onnxscript/optimizer/evaluator.py', # FIXME - 'onnxscript/optimizer/constant_folding.py', # FIXME + 'onnxscript/optimizer/_legacy/evaluator.py', # FIXME + 'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME 'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME 'onnxscript/_legacy_ir/irbuilder.py', # FIXME - 'onnxscript/optimizer/fold_constants_v0.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME 'onnxscript/tools/function_unittest_producer.py', # FIXME 'onnxscript/_legacy_ir/visitor.py', # FIXME diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 985ac6f10..f30976c24 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -2,160 +2,22 @@ # Licensed under the MIT License. from __future__ import annotations -import logging -from typing import Any - import onnx -import onnx.shape_inference - -from onnxscript import ir, rewriter -from onnxscript.optimizer import _constant_folding, _inliner -from onnxscript.optimizer.constant_folding import fold_constants -from onnxscript.optimizer.remove_unused import remove_unused_nodes -from onnxscript.optimizer.remove_unused_function import remove_unused_functions -from onnxscript.optimizer.simple_function_folding import ( - inline_functions_with_unused_outputs, - inline_simple_functions, -) -from onnxscript.rewriter import ( - broadcast_to_matmul, - cast_constant_of_shape, - gemm_to_matmul_add, - no_op, -) - -logger = logging.getLogger(__name__) - -_DEFAULT_REWRITE_RULES = [ - *no_op.rules.rules, # TODO: merge this rule into constant folding? - *broadcast_to_matmul.rules.rules, - gemm_to_matmul_add.rule, - *cast_constant_of_shape.rules.rules, -] - - -def optimize( - model: onnx.ModelProto, - num_iterations: int = 2, - *, - onnx_shape_inference: bool = True, - stop_if_no_change: bool = True, - external_data_folder: str = "", - **kwargs: Any, -) -> onnx.ModelProto: - """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc. - - Args: - model (onnx.ModelProto): The model to optimize. - num_iterations (int, optional): Number of iterations to perform. - onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model. - Set this to False to turn off onnx shape inference, and rely on model carried shapes and types. - This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries - the symbolic shapes recorded from dynamo tracing. - stop_if_no_change (bool, optional): Whether to stop if no change is detected. - external_data_folder (str, optional): The folder to store external data. - **kwargs: Additional keyword arguments. For BC purposes. - """ - if kwargs.pop("function_aware_folding", None) is not None: - logger.warning( - "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. " - "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. " - "This would turn off incremental onnx shape inference and rely on model carried shapes and types. " - "See 'onnx_shape_inference' for more details." - ) - for _ in range(num_iterations): - if onnx_shape_inference: - if model.ByteSize() < 1024 * 1024 * 1024 * 2: - # NOTE: strict mode is disabled because it crashes on the models - # that have different shapes inferred from the model carried shapes. - # The case can be found in: - # https://github.com/microsoft/onnxscript/issues/1443 - model = onnx.shape_inference.infer_shapes( - model, check_type=True, strict_mode=False, data_prop=True - ) - else: - logger.warning( - "The model size is too large for full model shape inference. " - "Skipping this step." - ) - - inline_simple_functions(model) - modified = fold_constants( - model, external_data_folder, onnx_shape_inference=onnx_shape_inference - ) - - remove_unused_nodes(model) - inline_simple_functions(model) - model = remove_unused_functions(model) - inline_functions_with_unused_outputs(model) - # NOTE: This is general rewrite rules - model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) - if stop_if_no_change and not modified: - logger.debug("Stopping after %d iterations.", _) - break - - for node in model.graph.node: - logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name) - - for function in model.functions: - for node in function.node: - logger.debug( - "Function %s::%s node %s::%s name %s.", - function.domain, - function.name, - node.domain, - node.op_type, - node.name, - ) - - return model - - -_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = ( - _constant_folding._DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT -) - -_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = ( - _constant_folding._DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT -) - -def optimize_ir( - model: ir.Model, - num_iterations: int = 2, - *, - onnx_shape_inference: bool = True, - stop_if_no_change: bool = True, - input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, - output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, -) -> None: - """Optimizes a model. +import onnxscript.optimizer._legacy._optimizer as legacy_optimizer +from onnxscript import ir +from onnxscript.optimizer._constant_folding import basic_constant_propagation +from onnxscript.optimizer._legacy.constant_folding import fold_constants +from onnxscript.optimizer._optimizer import optimize_ir +from onnxscript.optimizer._remove_unused import remove_unused_nodes - Args: - model: The model to be optimized. - num_iterations: Number of times the optimization loop is repeated. - onnx_shape_inference: Applies node-level shape-inference as part of optimization - input_size_limit: Will not apply constant folding to ops with any input of size - greater than this. Does not apply to special ops like Shape() and Size(). - output_size_limit: Will not rewrite any foldable-op into a Constant op if the size - of the output tensor is greater than this. - stop_if_no_change: Not supported currently (has no effect). Meant to stop the - outer optimization loop if no change is detected in one iteration. - """ - del stop_if_no_change # Looks like rewriter doesn't support this yet. - _inliner.inline(model) - for _ in range(num_iterations): - _constant_folding.fold_constants( - model, - onnx_shape_inference=onnx_shape_inference, - input_size_limit=input_size_limit, - output_size_limit=output_size_limit, - ) - rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) - remove_unused_nodes(model) +def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs): + if isinstance(model, ir.Model): + return optimize_ir(model, *args, **kwargs) + else: + return legacy_optimizer.optimize(model, *args, **kwargs) -basic_constant_propagation = _constant_folding.basic_constant_propagation __all__ = [ "fold_constants", diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 1144f207a..6a37efa16 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -17,20 +17,32 @@ import onnxscript.ir as ir import onnxscript.ir._convenience as _convenience -import onnxscript.optimizer.constant_folding as constant_folding import onnxscript.rewriter.pattern as orp import onnxscript.utils.utils as utils +DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024 + +DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024 + def is_control_flow_op(node: ir.Node) -> bool: graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS} return any(attr.type in graph_types for attr in node.attributes.values()) +non_deterministic_ops = frozenset( + { + "RandomUniform", + "RandomNormal", + "RandomUniformLike", + "RandomNormalLike", + "Multinomial", + } +) + + def is_non_deterministic_op(node: ir.Node) -> bool: - return node.op_type in constant_folding.non_deterministic_ops and utils.is_onnx_domain( - node.domain - ) + return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain) def is_onnx_op(node: ir.Node, op_type: str) -> bool: @@ -43,10 +55,6 @@ def is_constant_op(node: ir.Node) -> bool: ) -_DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024 - -_DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT - logger = logging.getLogger(__name__) # "Standard" evaluators are used to perform constant-folding. @@ -787,8 +795,8 @@ def fold_constants( external_data_folder: str = "", *, onnx_shape_inference: bool = False, - input_size_limit: int = _DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, - output_size_limit: int = _DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, + input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, + output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, ) -> bool: """ Applies constant folding optimization to the model. diff --git a/onnxscript/optimizer/constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py similarity index 99% rename from onnxscript/optimizer/constant_folding_test.py rename to onnxscript/optimizer/_constant_folding_test.py index 7629653d4..b80f01c8f 100644 --- a/onnxscript/optimizer/constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -8,7 +8,8 @@ import onnxscript.optimizer as optimizer from onnxscript.ir import serde -from onnxscript.optimizer import _constant_folding, constant_folding +from onnxscript.optimizer import _constant_folding +from onnxscript.optimizer._legacy import constant_folding @parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) diff --git a/onnxscript/optimizer/function_folding_test.py b/onnxscript/optimizer/_function_folding_test.py similarity index 100% rename from onnxscript/optimizer/function_folding_test.py rename to onnxscript/optimizer/_function_folding_test.py diff --git a/onnxscript/optimizer/_legacy/_optimizer.py b/onnxscript/optimizer/_legacy/_optimizer.py new file mode 100644 index 000000000..f913bb465 --- /dev/null +++ b/onnxscript/optimizer/_legacy/_optimizer.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import logging +from typing import Any + +import onnx +import onnx.shape_inference + +from onnxscript import rewriter +from onnxscript.optimizer._legacy._simple_function_folding import ( + inline_functions_with_unused_outputs, + inline_simple_functions, +) +from onnxscript.optimizer._legacy.constant_folding import fold_constants +from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES +from onnxscript.optimizer._remove_unused import remove_unused_nodes +from onnxscript.optimizer._remove_unused_function import remove_unused_functions + +logger = logging.getLogger(__name__) + + +def optimize( + model: onnx.ModelProto, + num_iterations: int = 2, + *, + onnx_shape_inference: bool = True, + stop_if_no_change: bool = True, + external_data_folder: str = "", + **kwargs: Any, +) -> onnx.ModelProto: + """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc. + + Args: + model (onnx.ModelProto): The model to optimize. + num_iterations (int, optional): Number of iterations to perform. + onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model. + Set this to False to turn off onnx shape inference, and rely on model carried shapes and types. + This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries + the symbolic shapes recorded from dynamo tracing. + stop_if_no_change (bool, optional): Whether to stop if no change is detected. + external_data_folder (str, optional): The folder to store external data. + **kwargs: Additional keyword arguments. For BC purposes. + """ + if kwargs.pop("function_aware_folding", None) is not None: + logger.warning( + "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. " + "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. " + "This would turn off incremental onnx shape inference and rely on model carried shapes and types. " + "See 'onnx_shape_inference' for more details." + ) + for _ in range(num_iterations): + if onnx_shape_inference: + if model.ByteSize() < 1024 * 1024 * 1024 * 2: + # NOTE: strict mode is disabled because it crashes on the models + # that have different shapes inferred from the model carried shapes. + # The case can be found in: + # https://github.com/microsoft/onnxscript/issues/1443 + model = onnx.shape_inference.infer_shapes( + model, check_type=True, strict_mode=False, data_prop=True + ) + else: + logger.warning( + "The model size is too large for full model shape inference. " + "Skipping this step." + ) + + inline_simple_functions(model) + modified = fold_constants( + model, external_data_folder, onnx_shape_inference=onnx_shape_inference + ) + + remove_unused_nodes(model) + inline_simple_functions(model) + model = remove_unused_functions(model) + inline_functions_with_unused_outputs(model) + # NOTE: This is general rewrite rules + model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) + if stop_if_no_change and not modified: + logger.debug("Stopping after %d iterations.", _) + break + + for node in model.graph.node: + logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name) + + for function in model.functions: + for node in function.node: + logger.debug( + "Function %s::%s node %s::%s name %s.", + function.domain, + function.name, + node.domain, + node.op_type, + node.name, + ) + + return model diff --git a/onnxscript/optimizer/remove_unused_proto.py b/onnxscript/optimizer/_legacy/_remove_unused_proto.py similarity index 100% rename from onnxscript/optimizer/remove_unused_proto.py rename to onnxscript/optimizer/_legacy/_remove_unused_proto.py diff --git a/onnxscript/optimizer/simple_function_folding.py b/onnxscript/optimizer/_legacy/_simple_function_folding.py similarity index 98% rename from onnxscript/optimizer/simple_function_folding.py rename to onnxscript/optimizer/_legacy/_simple_function_folding.py index 512bd104c..829bae9d6 100644 --- a/onnxscript/optimizer/simple_function_folding.py +++ b/onnxscript/optimizer/_legacy/_simple_function_folding.py @@ -11,7 +11,7 @@ import onnxscript._legacy_ir as ir from onnxscript._legacy_ir import visitor -from onnxscript.optimizer import remove_unused_proto +from onnxscript.optimizer._legacy import _remove_unused_proto logger = logging.getLogger(__name__) @@ -168,7 +168,7 @@ def _find_nodes_with_any_unused_output( # All unused output means the node is not used at all. # Hence do not update used_values with the node's inputs. continue - used_values |= remove_unused_proto.compute_used_in_node(node) + used_values |= _remove_unused_proto.compute_used_in_node(node) return target_nodes def visit_model(self, model: onnx.ModelProto) -> None: diff --git a/onnxscript/optimizer/simple_function_folding_test.py b/onnxscript/optimizer/_legacy/_simple_function_folding_test.py similarity index 84% rename from onnxscript/optimizer/simple_function_folding_test.py rename to onnxscript/optimizer/_legacy/_simple_function_folding_test.py index ffb987476..aa0af61a0 100644 --- a/onnxscript/optimizer/simple_function_folding_test.py +++ b/onnxscript/optimizer/_legacy/_simple_function_folding_test.py @@ -6,7 +6,8 @@ import onnx -from onnxscript.optimizer import remove_unused_function, simple_function_folding +from onnxscript.optimizer import _remove_unused_function +from onnxscript.optimizer._legacy import _simple_function_folding class SingleNodeFunctionFoldingTest(unittest.TestCase): @@ -32,8 +33,8 @@ def test_fold_single_node_function(self): """ ) - simple_function_folding.inline_simple_functions(model) - model = remove_unused_function.remove_unused_functions(model) + _simple_function_folding.inline_simple_functions(model) + model = _remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 0) @@ -59,8 +60,8 @@ def test_fold_single_node_function_ref_attr(self): """ ) - simple_function_folding.inline_simple_functions(model) - model = remove_unused_function.remove_unused_functions(model) + _simple_function_folding.inline_simple_functions(model) + model = _remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 0) self.assertFalse(model.graph.node[0].attribute[0].ref_attr_name) @@ -98,8 +99,8 @@ def test_fold_single_node_function_nested(self): """ ) - simple_function_folding.inline_simple_functions(model) - model = remove_unused_function.remove_unused_functions(model) + _simple_function_folding.inline_simple_functions(model) + model = _remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 1) self.assertEqual(model.functions[0].node[0].op_type, "Concat") @@ -127,8 +128,8 @@ def test_fold_single_node_function_create_new_nodes_with_correct_attributes(self } """ ) - simple_function_folding.inline_simple_functions(model) - model = remove_unused_function.remove_unused_functions(model) + _simple_function_folding.inline_simple_functions(model) + model = _remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 0) self.assertEqual(len(model.graph.node), 3) self.assertEqual(model.graph.node[0].attribute[0].i, 10) @@ -170,8 +171,8 @@ def test_fold_nested_if_function_succeeds(self): """ ) - simple_function_folding.inline_simple_functions(model) - model = remove_unused_function.remove_unused_functions(model) + _simple_function_folding.inline_simple_functions(model) + model = _remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 0) self.assertEqual(len(model.graph.node), 2) @@ -211,8 +212,8 @@ def test_fold_function_with_unused_output(self): """ ) - simple_function_folding.inline_functions_with_unused_outputs(model) - model = remove_unused_function.remove_unused_functions(model) + _simple_function_folding.inline_functions_with_unused_outputs(model) + model = _remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 1) diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/_legacy/constant_folding.py similarity index 96% rename from onnxscript/optimizer/constant_folding.py rename to onnxscript/optimizer/_legacy/constant_folding.py index d119c41e9..d30a8c9cc 100644 --- a/onnxscript/optimizer/constant_folding.py +++ b/onnxscript/optimizer/_legacy/constant_folding.py @@ -10,8 +10,9 @@ import onnx.reference.ops import onnxscript._legacy_ir as ir +import onnxscript.optimizer._constant_folding as _constant_folding from onnxscript._legacy_ir import visitor -from onnxscript.optimizer import evaluator +from onnxscript.optimizer._legacy import evaluator from onnxscript.utils.utils import ( is_control_flow_op, is_onnx_domain, @@ -19,26 +20,15 @@ logger = logging.getLogger(__name__) -_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024 - # Ops excluded from constant-propagation: # * Random ops, which are not deterministic (checked below) # * Control flow ops (checked by presence of graph-attribute) -non_deterministic_ops = frozenset( - { - "RandomUniform", - "RandomNormal", - "RandomUniformLike", - "RandomNormalLike", - "Multinomial", - } -) - onnx_domain = frozenset({"", "onnx.ai"}) def is_non_deterministic_op(node: onnx.NodeProto) -> bool: + non_deterministic_ops = _constant_folding.non_deterministic_ops return node.op_type in non_deterministic_ops and is_onnx_domain(node.domain) @@ -89,7 +79,7 @@ def foldable_value(self, name: str, value): ) return None - if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: + if value.nbytes > _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT: logger.info( "Skip storing constant folded nvalue %s due to large size %s.", name, diff --git a/onnxscript/optimizer/evaluator.py b/onnxscript/optimizer/_legacy/evaluator.py similarity index 100% rename from onnxscript/optimizer/evaluator.py rename to onnxscript/optimizer/_legacy/evaluator.py diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py new file mode 100644 index 000000000..b5f4bcde0 --- /dev/null +++ b/onnxscript/optimizer/_optimizer.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import logging + +from onnxscript import ir, rewriter +from onnxscript.optimizer import _constant_folding, _inliner +from onnxscript.optimizer._remove_unused import remove_unused_nodes +from onnxscript.rewriter import ( + broadcast_to_matmul, + cast_constant_of_shape, + gemm_to_matmul_add, + no_op, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_REWRITE_RULES = [ + *no_op.rules.rules, # TODO: merge this rule into constant folding? + *broadcast_to_matmul.rules.rules, + gemm_to_matmul_add.rule, + *cast_constant_of_shape.rules.rules, +] + + +def optimize_ir( + model: ir.Model, + num_iterations: int = 2, + *, + onnx_shape_inference: bool = True, + stop_if_no_change: bool = True, + input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, + output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, +) -> None: + """Optimizes a model. + + Args: + model: The model to be optimized. + num_iterations: Number of times the optimization loop is repeated. + onnx_shape_inference: Applies node-level shape-inference as part of optimization + input_size_limit: Will not apply constant folding to ops with any input of size + greater than this. Does not apply to special ops like Shape() and Size(). + output_size_limit: Will not rewrite any foldable-op into a Constant op if the size + of the output tensor is greater than this. + stop_if_no_change: Not supported currently (has no effect). Meant to stop the + outer optimization loop if no change is detected in one iteration. + """ + del stop_if_no_change # Looks like rewriter doesn't support this yet. + _inliner.inline(model) + for _ in range(num_iterations): + _constant_folding.fold_constants( + model, + onnx_shape_inference=onnx_shape_inference, + input_size_limit=input_size_limit, + output_size_limit=output_size_limit, + ) + rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) + remove_unused_nodes(model) diff --git a/onnxscript/optimizer/optimizer_test.py b/onnxscript/optimizer/_optimizer_test.py similarity index 100% rename from onnxscript/optimizer/optimizer_test.py rename to onnxscript/optimizer/_optimizer_test.py diff --git a/onnxscript/optimizer/remove_unused_ir.py b/onnxscript/optimizer/_remove_unused.py similarity index 88% rename from onnxscript/optimizer/remove_unused_ir.py rename to onnxscript/optimizer/_remove_unused.py index 9fa73ca10..abd6f79b1 100644 --- a/onnxscript/optimizer/remove_unused_ir.py +++ b/onnxscript/optimizer/_remove_unused.py @@ -6,6 +6,7 @@ import onnx +import onnxscript.optimizer._legacy._remove_unused_proto from onnxscript import ir logger = logging.getLogger(__name__) @@ -81,8 +82,8 @@ def process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int: return count -def remove_unused_nodes(model: ir.Model) -> None: - """Removes unused nodes from the model.""" +def _remove_unused_nodes(model: ir.Model) -> None: + """Removes unused nodes from a model in IR form.""" count = process_function_or_graph(model.graph) graph_outputs = frozenset(model.graph.outputs) initializers = model.graph.initializers @@ -95,3 +96,11 @@ def remove_unused_nodes(model: ir.Model) -> None: count += process_function_or_graph(function) logger.info("Removed %s unused nodes", count) + + +def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None: + """Removes unused nodes from a model.""" + if isinstance(model, ir.Model): + _remove_unused_nodes(model) + else: + onnxscript.optimizer._legacy._remove_unused_proto.remove_unused_nodes(model) diff --git a/onnxscript/optimizer/remove_unused_function.py b/onnxscript/optimizer/_remove_unused_function.py similarity index 100% rename from onnxscript/optimizer/remove_unused_function.py rename to onnxscript/optimizer/_remove_unused_function.py diff --git a/onnxscript/optimizer/remove_unused_test.py b/onnxscript/optimizer/_remove_unused_test.py similarity index 100% rename from onnxscript/optimizer/remove_unused_test.py rename to onnxscript/optimizer/_remove_unused_test.py diff --git a/onnxscript/optimizer/fold_constants_v0.py b/onnxscript/optimizer/fold_constants_v0.py deleted file mode 100644 index 9be7c9eda..000000000 --- a/onnxscript/optimizer/fold_constants_v0.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -from typing import Any, Sequence - -import numpy as np -import onnx -import onnx.reference.ops - -# Excluded ops include -# * Random ops, which are not deterministic -# * Control flow ops - -excluded_ops = frozenset( - { - "RandomUniform", - "RandomNormal", - "RandomUniformLike", - "RandomNormalLike", - "Multinomial", - "If", - "Loop", - "Scan", - "SequenceMap", - } -) - -onnx_domain = frozenset({"", "onnx.ai"}) - - -def get_evaluator(domain: str, op: str, version: int) -> callable | None: - if op in excluded_ops and domain in onnx_domain: - return None - try: - op_impl_class = onnx.reference.ops.load_op(domain, op, version) - except Exception: - return None - else: - return op_impl_class.eval - - -def convert_attributes(attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]: - return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes} - - -def is_control_flow_op(node: onnx.NodeProto) -> bool: - return any(attr.HasField("g") or len(attr.graphs) > 0 for attr in node.attribute) - - -def is_constant_op(node: onnx.NodeProto) -> bool: - return node.op_type == "Constant" and node.domain == "" - - -def get_bool_value(val) -> bool | None: - if isinstance(val, bool): - return val - if isinstance(val, np.bool_): - return bool(val) - if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: - return val.item(0) - return None - - -def get_shape_info(type: onnx.TypeProto) -> tuple[int, ...] | None: - if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): - return np.array([d.dim_value for d in type.tensor_type.shape.dim], dtype=np.int64) - return None - - -def get_element_type(type: onnx.TypeProto) -> int | None: - if type.HasField("tensor_type"): - return type.tensor_type.elem_type - return None - - -class State: - def __init__(self, default_value) -> None: - self.scopes = [{}] - self.default_value = default_value - - def lookup(self, name: str) -> Any: - for scope in reversed(self.scopes): - if name in scope: - return scope[name] - return self.default_value - - def bind(self, name: str, value: Any) -> None: - self.scopes[-1][name] = value - - def enter_scope(self) -> None: - self.scopes.append({}) - - def exit_scope(self) -> None: - self.scopes.pop() - - -def is_onnx_op(node: onnx.NodeProto, op: str) -> bool: - return (node.op_type == op) and (node.domain in onnx_domain) - - -def matches(node: onnx.NodeProto, op: str, *arg_predicates) -> bool: - if node.op_type != op or node.domain != "": - return False - if len(node.input) < len(arg_predicates): - return False - return all(pred(input) for pred, input in zip(arg_predicates, node.input)) - - -def get_initializer_type(initializer: onnx.TensorProto) -> onnx.TypeProto: - type = onnx.TypeProto() - type.tensor_type.elem_type = initializer.data_type - dims = type.tensor_type.shape.dim - for dim in initializer.dims: - dims.add().dim_value = dim - return type - - -def fold_constants(model: onnx.ModelProto): - not_constant = object() - var_info = State(default_value=not_constant) - type_info = State(default_value=None) - counts = {} - sizes = {} - - def add_count(op: str, size: int = 1): - counts[op] = counts.get(op, 0) + 1 - sizes[op] = sizes.get(op, 0) + size - - def new_constant(name, value): - var_info.bind(name, value) - tensor = onnx.numpy_helper.from_array(value, name=name) - node = onnx.helper.make_node("Constant", inputs=[], outputs=[name], value=tensor) - return node - - def lookup_version(domain: str, op: str) -> int: - for opset in model.opset_import: - if opset.domain == domain: - return opset.version - return 1 # TODO - - def transform_node(node: onnx.NodeProto): - if is_onnx_op(node, "Transpose"): - return [node] - if is_onnx_op(node, "CastLike"): - value = var_info.lookup(node.input[0]) if len(node.input) > 0 else not_constant - if value is not_constant: - return [node] - type = type_info.lookup(node.input[1]) if len(node.input) > 1 else None - element_type = get_element_type(type) if type is not None else None - if element_type is None: - return [node] - evaluator = get_evaluator("", "Cast", lookup_version("", "Cast")) - if evaluator is None: - return [node] - cast_value = evaluator(value, to=element_type) - add_count("CastLike", cast_value.size) - return [new_constant(node.output[0], cast_value)] - if is_onnx_op(node, "Shape"): - type = type_info.lookup(node.input[0]) if len(node.input) > 0 else None - shape = get_shape_info(type) if type is not None else None - if shape is not None: - add_count("Shape", shape.size) - return [new_constant(node.output[0], shape)] - - if is_onnx_op(node, "If"): - cond = var_info.lookup(node.input[0]) if len(node.input) > 0 else None - cond = get_bool_value(cond) - if cond is not None: - # cond is a constant-value: inline the branch - branch = "then_branch" if cond else "else_branch" - graph = onnx.helper.get_node_attr_value(node, branch) - formal_outs = list(graph.output) - actual_outs = node.output - renamings = { - formal.name: actual - for formal, actual in zip(formal_outs, actual_outs) - if actual != "" - } - - def rename(name): - return renamings.get(name, name) - - for node in graph.node: - node.input[:] = [rename(name) for name in node.input] - node.output[:] = [rename(name) for name in node.output] - transform_graph(graph) - add_count("If") - return list(graph.node) - - if is_control_flow_op(node): - for attr in node.attribute: - if attr.HasField("g"): - transform_graph(attr.g) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - transform_graph(graph) - return [node] - - domain = node.domain - op = node.op_type - version = lookup_version(domain, op) - inputs = [] - for x in node.input: - if x == "": - inputs.append(None) - else: - v = var_info.lookup(x) - if v is not_constant: - return [node] - inputs.append(v) - evaluator = get_evaluator(domain, op, version) - if evaluator is None: - return [node] - attrs = convert_attributes(node.attribute) - outputs = evaluator(*inputs, **attrs) - if len(node.output) == 1 and not isinstance(outputs, tuple): - replacement = new_constant(node.output[0], outputs) - if is_constant_op(node): - return [node] - add_count(op, outputs.size) - return [replacement] - else: - add_count(op) - return [new_constant(output, outputs[i]) for i, output in enumerate(node.output)] - - def transform_graph(graph: onnx.GraphProto): - var_info.enter_scope() - type_info.enter_scope() - for initializer in graph.initializer: - array = onnx.numpy_helper.to_array(initializer) - var_info.bind(initializer.name, array) - type_info.bind(initializer.name, get_initializer_type(initializer)) - for input in graph.input: - var_info.bind(input.name, not_constant) - type_info.bind(input.name, input.type) - for valueinfo in graph.value_info: - type_info.bind(valueinfo.name, valueinfo.type) - - replacement = [transform_node(node) for node in graph.node] - flattened = [node for nodes in replacement for node in nodes] - del graph.node[:] - graph.node.extend(flattened) - var_info.exit_scope() - type_info.exit_scope() - - transform_graph(model.graph) - for op in counts: - print(f"Constant-folded '{op}' {counts[op]} times, with {sizes[op]} size.") diff --git a/onnxscript/optimizer/remove_unused.py b/onnxscript/optimizer/remove_unused.py deleted file mode 100644 index 567362d60..000000000 --- a/onnxscript/optimizer/remove_unused.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import onnx - -import onnxscript.optimizer.remove_unused_ir -import onnxscript.optimizer.remove_unused_proto -from onnxscript import ir - - -def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None: - if isinstance(model, ir.Model): - onnxscript.optimizer.remove_unused_ir.remove_unused_nodes(model) - else: - onnxscript.optimizer.remove_unused_proto.remove_unused_nodes(model) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index e6d1e85ff..421535553 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -15,7 +15,7 @@ import onnx from onnxscript import ir -from onnxscript.optimizer import remove_unused, remove_unused_function +from onnxscript.optimizer import _remove_unused, _remove_unused_function from onnxscript.rewriter import function_rule, pattern RewriteRuleSet = pattern.RewriteRuleSet @@ -48,8 +48,8 @@ def rewrite( count = pattern_rewrite_rules.apply_to_model(model_ir) if count: print(f"Applied {count} of general pattern rewrite rules.") - remove_unused.remove_unused_nodes(model_ir) - model_ir = remove_unused_function.remove_unused_functions(model_ir) + _remove_unused.remove_unused_nodes(model_ir) + model_ir = _remove_unused_function.remove_unused_functions(model_ir) if proto: model = ir.serde.serialize_model(model_ir) return model diff --git a/onnxscript/tools/benchmark/benchmark_helpers.py b/onnxscript/tools/benchmark/benchmark_helpers.py index 3a874fa46..08951b39e 100644 --- a/onnxscript/tools/benchmark/benchmark_helpers.py +++ b/onnxscript/tools/benchmark/benchmark_helpers.py @@ -25,7 +25,7 @@ import onnxscript.rewriter.onnxruntime as ort_rules import onnxscript.rewriter.pattern as orp from onnxscript import ir -from onnxscript.optimizer.remove_unused import remove_unused_nodes +from onnxscript.optimizer._remove_unused import remove_unused_nodes def get_parsed_args(