Skip to content

Turn constant folder and dce into passes #2109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,7 @@ def merge_dims(dim1, dim2):
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])


class ConstantFolder:
opset_imports: dict[str, int]

class FoldConstantsPass(ir.passes.PassBase):
def __init__(
self,
*,
Expand All @@ -812,11 +810,17 @@ def __init__(
self._shape_inference = shape_inference
self._input_size_limit = input_size_limit
self._output_size_limit = output_size_limit
self._init()

def _init(self) -> None:
self.opset_imports: dict[str, int] = {}
self.counts: dict[str, int] = {}
self.sizes: dict[str, int] = {}
self.modified: bool = False
self._state = OptimizerState()
self._reset()

def _reset(self) -> None:
"""Reset internal states for a new run."""
self.counts = {}
self.sizes = {}
self.modified = False
self._state = OptimizerState()

Expand Down Expand Up @@ -931,6 +935,7 @@ def process_node(self, node: ir.Node):
sym_value.name,
)
node.replace_input_with(i, sym_value)
self.modified = True
# TODO(rama): consider merging type/other info from both values

# Do incremental shape inference
Expand Down Expand Up @@ -1007,6 +1012,8 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)

self.modified = True

# TODO: what about new opset_imports?
# TODO: track statistics about replaced nodes and sizes of new constants

Expand Down Expand Up @@ -1045,13 +1052,14 @@ def visit_function(self, function: ir.Function) -> None:
for node in function:
self.visit_node(node, function)

def visit_model(self, model: ir.Model) -> None:
self._init()
def call(self, model: ir.Model) -> ir.passes.PassResult:
    """Apply constant folding to the entire model.

    Resets the per-run state, then visits the main graph followed by every
    model-local function. Returns a ``PassResult`` whose ``modified`` flag
    is True iff any node was folded or rewired during this run.
    """
    self._reset()
    # Adopt the model's own opset imports so op-version-dependent folding
    # decisions match the model being processed.
    self.opset_imports = model.opset_imports
    self.visit_graph(model.graph)
    for function in model.functions.values():
        # TODO(rama): Should we specialize functions?
        self.visit_function(function)
    return ir.passes.PassResult(model, self.modified)


def fold_constants(
Expand All @@ -1066,18 +1074,18 @@ def fold_constants(
Applies constant folding optimization to the model.
Returns true iff the model was modified.
"""
folder = ConstantFolder(
folder_pass = FoldConstantsPass(
external_data_folder=external_data_folder,
shape_inference=onnx_shape_inference,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
)
folder.visit_model(model)
for op in folder.counts:
folder_pass(model)
for op in folder_pass.counts:
logger.info(
"Constant-folded '%s' %s times, with %s size.",
op,
folder.counts[op],
folder.sizes[op],
folder_pass.counts[op],
folder_pass.sizes[op],
)
return folder.modified
return folder_pass.modified
38 changes: 20 additions & 18 deletions onnxscript/optimizer/_remove_unused.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
out.name = ""


def process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int:
def _process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int:
graph_outputs = frozenset(function_or_graph.outputs)
onnx_opset_version = function_or_graph.opset_imports.get("", None)
count = 0
Expand All @@ -75,32 +75,34 @@
if not isinstance(attr, ir.Attr):
continue
if attr.type == ir.AttributeType.GRAPH:
count += process_function_or_graph(attr.as_graph())
count += _process_function_or_graph(attr.as_graph())
elif attr.type == ir.AttributeType.GRAPHS:
for graph in attr.as_graphs():
count += process_function_or_graph(graph)
count += _process_function_or_graph(graph)

Check warning on line 81 in onnxscript/optimizer/_remove_unused.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_remove_unused.py#L81

Added line #L81 was not covered by tests
return count


def _remove_unused_nodes(model: ir.Model) -> None:
"""Removes unused nodes from a model in IR form."""
count = process_function_or_graph(model.graph)
graph_outputs = frozenset(model.graph.outputs)
initializers = model.graph.initializers
for init in list(initializers.values()):
if not (init in graph_outputs or init.uses()):
del initializers[init.name] # type: ignore[arg-type]
count += 1

for function in model.functions.values():
count += process_function_or_graph(function)

logger.info("Removed %s unused nodes", count)
class RemoveUnusedNodesPass(ir.passes.PassBase):
    """Pass that prunes dead nodes and unused initializers from a model."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Remove unused nodes and initializers; report whether anything changed."""
        removed = _process_function_or_graph(model.graph)

        # An initializer is dead when it is neither a graph output nor
        # consumed by any node; drop it from the initializer table.
        outputs = frozenset(model.graph.outputs)
        inits = model.graph.initializers
        dead_inits = [
            init
            for init in list(inits.values())
            if init not in outputs and not init.uses()
        ]
        for init in dead_inits:
            assert init.name is not None
            del inits[init.name]
        removed += len(dead_inits)

        for fn in model.functions.values():
            removed += _process_function_or_graph(fn)

        changed = bool(removed)
        if changed:
            logger.info("Removed %s unused nodes", removed)
        return ir.passes.PassResult(model, modified=changed)


def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused nodes from a model.

    Dispatches to the IR-based :class:`RemoveUnusedNodesPass` for
    ``ir.Model`` inputs, and to the legacy proto-based implementation
    for ``onnx.ModelProto`` inputs.
    """
    if not isinstance(model, ir.Model):
        # Proto input: delegate to the legacy proto-level implementation.
        onnxscript.optimizer._legacy._remove_unused_proto.remove_unused_nodes(model)
        return
    RemoveUnusedNodesPass()(model)
Loading