33from __future__ import annotations
44
55import logging
6+ from typing import Callable
67
78import onnx_ir as ir
89import onnx_ir .passes .common as common_passes
@@ -21,6 +22,7 @@ def optimize_ir(
2122 stop_if_no_change : bool = True ,
2223 input_size_limit : int = _constant_folding .DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT ,
2324 output_size_limit : int = _constant_folding .DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT ,
25+ should_fold : Callable [[ir .Node ], bool | None ] = lambda node : None ,
2426 inline : bool = True ,
2527) -> None :
2628 """Optimizes a model.
@@ -29,11 +31,15 @@ def optimize_ir(
2931 model: The model to be optimized.
3032 num_iterations: Number of times the optimization loop is repeated.
3133 onnx_shape_inference: Applies node-level shape-inference as part of optimization
34+ stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
3235 input_size_limit: Will not apply constant folding to ops with any input of size
3336 greater than this. Does not apply to special ops like Shape() and Size().
3437 output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
3538 of the output tensor is greater than this.
36- stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
39+ should_fold: An optional function that takes a node and returns True if
40+ the node should be considered for folding.
41+ The function should return True/False value to indicate if this particular
42+ node should be folded, or None to use the default folding rules.
3743 inline: If True, inlines all functions in the model.
3844 """
3945 passes = [
@@ -43,6 +49,7 @@ def optimize_ir(
4349 shape_inference = onnx_shape_inference ,
4450 input_size_limit = input_size_limit ,
4551 output_size_limit = output_size_limit ,
52+ should_fold = should_fold ,
4653 ),
4754 rewriter .RewritePass (rewriter ._DEFAULT_REWRITE_RULES ),
4855 common_passes .RemoveUnusedNodesPass (),
0 commit comments