Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/api/optimizer.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@
optimizer.inline
optimizer.basic_constant_propagation
optimizer.fold_constants
optimizer.remove_unused_nodes
```
8 changes: 2 additions & 6 deletions onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@

import onnxscript.optimizer._constant_folding as constant_folding
from onnxscript import ir
from onnxscript.optimizer._constant_folding import (
basic_constant_propagation,
)
from onnxscript.optimizer._constant_folding import (
fold_constants as fold_constants_ir,
)
from onnxscript.optimizer._constant_folding import basic_constant_propagation
from onnxscript.optimizer._constant_folding import fold_constants as fold_constants_ir
from onnxscript.optimizer._optimizer import optimize_ir

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
Expand Down
35 changes: 34 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@

from __future__ import annotations

__all__ = [
"basic_constant_propagation",
"fold_constants",
"FoldConstantsPass",
"FOLDED_FROM_KEY",
]

import dataclasses
import logging
import math
Expand All @@ -23,6 +30,9 @@

DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512

# Key used to store the metadata
FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.folded_from"


_NON_DETERMINISTIC_OPS = frozenset(
{
Expand Down Expand Up @@ -914,6 +924,24 @@ def merge_dims(dim1, dim2):
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])


def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None:
"""Record the set of original input values that contributed to the constant-folded outputs."""
folded_from: set[str] = set()
for input in original_node.inputs:
if input is None:
continue
folded_from.update(input.meta.get(FOLDED_FROM_KEY, set()))
assert input.name is not None
folded_from.add(input.name)

for new_output in replacement.new_outputs:
if new_output is None:
continue
new_output.meta[FOLDED_FROM_KEY] = folded_from
# Store the string representation of the set to metadata_props to persist it across serialization
new_output.metadata_props[FOLDED_FROM_KEY] = repr(sorted(folded_from))


class FoldConstantsPass(ir.passes.InPlacePass):
"""A pass that folds constant expressions in the model.

Expand Down Expand Up @@ -1203,9 +1231,14 @@ def convert(av):
)
return None

def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
def replace_node(
self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function
) -> None:
logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)

# Record the names of the values that has contributed to the replacement
_record_contributing_values(node, replacement)

ir.convenience.replace_nodes_and_values(
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)
Expand Down
Loading