Skip to content

[pass] Add CSE #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
313c756
add cse
titaiwangms May 30, 2025
3b11aac
lint and add test dependency
titaiwangms May 30, 2025
f5f086b
add onnxscript
titaiwangms May 30, 2025
e673d5b
add --no-deps
titaiwangms May 30, 2025
a047bef
fix no-deps
titaiwangms May 30, 2025
b25a41b
fix test import
titaiwangms May 30, 2025
e0917d8
Update noxfile.py
justinchuby May 30, 2025
f9a1c1f
Update src/onnx_ir/passes/common/common_subexpression_elimination_tes…
titaiwangms May 30, 2025
0de5ce2
Update src/onnx_ir/passes/common/common_subexpression_elimination.py
titaiwangms May 30, 2025
6a5746e
skip non deterministic ops
titaiwangms Jun 2, 2025
61fb5f2
Rename reuse job to make required (#37)
justinchuby May 31, 2025
9dc05ce
Configure the do-not-merge bot (#40)
justinchuby May 31, 2025
5bd0fa0
Remove RefAttr from docs because it is not a class (#35)
justinchuby Jun 1, 2025
19cdd00
Revert changes to LiftConstantsToInitializersPass (#41)
justinchuby Jun 1, 2025
383bcba
Bump version number to v0.1.1 (#42)
justinchuby Jun 1, 2025
8fe5d5d
Bump ruff from 0.11.11 to 0.11.12 in /requirements/lintrunner (#47)
dependabot[bot] Jun 2, 2025
11beaee
Bump mypy from 1.15.0 to 1.16.0 in /requirements/lintrunner (#46)
dependabot[bot] Jun 2, 2025
750b691
Bump ossf/scorecard-action from 2.4.1 to 2.4.2 (#48)
dependabot[bot] Jun 2, 2025
3cef183
Prevent values produced by other nodes to be added as graph inputs (#51)
justinchuby Jun 2, 2025
53468eb
Create release.yml for generating release notes (#52)
justinchuby Jun 2, 2025
220d1d8
Create `add` methods on initializers and attributes (#33)
justinchuby Jun 2, 2025
174742c
Update codecov.yml to ignore unit tests (#53)
justinchuby Jun 2, 2025
1823b62
Merge branch 'main' into titaiwang/add_cse
justinchuby Jun 2, 2025
3002e3b
Update src/onnx_ir/passes/common/common_subexpression_elimination.py
justinchuby Jun 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
"types-PyYAML",
"typing_extensions>=4.10",
"ml-dtypes",
"onnxruntime",
)
ONNX = "onnx==1.18"
ONNXSCRIPT = "onnxscript"
ONNX_RUNTIME = "onnxruntime==1.20.1"
PYTORCH = "torch==2.7.0"
TORCHVISON = "torchvision==0.22.0"
Expand All @@ -50,6 +52,7 @@ def test(session):
ONNX,
PYTORCH,
)
session.install(ONNXSCRIPT, "--no-deps")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run("pytest", "src", "--doctest-modules", *session.posargs)
Expand All @@ -61,6 +64,7 @@ def test_onnx_weekly(session):
"""Test with ONNX weekly (preview) build."""
session.install(*COMMON_TEST_DEPENDENCIES, PYTORCH)
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
session.install(ONNXSCRIPT, "--no-deps")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run("pytest", "src", "--doctest-modules", *session.posargs)
Expand All @@ -73,6 +77,7 @@ def test_torch_nightly(session):
session.install(*COMMON_TEST_DEPENDENCIES)
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
session.install("-r", "requirements/ci/requirements-pytorch-nightly.txt")
session.install(ONNXSCRIPT, "--no-deps")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run("pytest", "src", "--doctest-modules", *session.posargs)
Expand Down
4 changes: 4 additions & 0 deletions src/onnx_ir/passes/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"AddInitializersToInputsPass",
"CheckerPass",
"ClearMetadataAndDocStringPass",
"CommonSubexpressionEliminationPass",
"InlinePass",
"LiftConstantsToInitializersPass",
"LiftSubgraphInitializersToMainGraphPass",
Expand All @@ -19,6 +20,9 @@
from onnx_ir.passes.common.clear_metadata_and_docstring import (
ClearMetadataAndDocStringPass,
)
from onnx_ir.passes.common.common_subexpression_elimination import (
CommonSubexpressionEliminationPass,
)
from onnx_ir.passes.common.constant_manipulation import (
AddInitializersToInputsPass,
LiftConstantsToInitializersPass,
Expand Down
177 changes: 177 additions & 0 deletions src/onnx_ir/passes/common/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Eliminate common subexpression in ONNX graphs."""

from __future__ import annotations

__all__ = [
"CommonSubexpressionEliminationPass",
]

import logging
from collections.abc import Sequence

import onnx_ir as ir

logger = logging.getLogger(__name__)


class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
"""Eliminate common subexpression in ONNX graphs."""

def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Return the same ir.Model but with CSE applied to the graph."""
modified = False
graph = model.graph

modified = _eliminate_common_subexpression(graph, modified)

return ir.passes.PassResult(
model,
modified=modified,
)


def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
"""Eliminate common subexpression in ONNX graphs."""
# node to node identifier, length of outputs, inputs, and attributes
existing_node_info_to_the_node: dict[
tuple[
ir.OperatorIdentifier,
int, # len(outputs)
tuple[int, ...], # input ids
tuple[tuple[str, object], ...], # attributes
],
ir.Node,
] = {}

for node in graph:
# Skip control flow ops like Loop and If.
control_flow_op: bool = False
# Use equality to check if the node is a common subexpression.
attributes = {}
for k, v in node.attributes.items():
# TODO(exporter team): CSE subgraphs.
# NOTE: control flow ops like Loop and If won't be CSEd
# because attribute: graph won't match.
if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
control_flow_op = True
logger.debug("Skipping control flow op %s", node)
# The attribute value could be directly taken from the original
# protobuf, so we need to make a copy of it.
value = v.value
if v.type in (
ir.AttributeType.INTS,
ir.AttributeType.FLOATS,
ir.AttributeType.STRINGS,
):
# For INT, FLOAT and STRING attributes, we convert them to tuples
# to ensure they are hashable.
value = tuple(value)
attributes[k] = value

if control_flow_op:
# If the node is a control flow op, we skip it.
logger.debug("Skipping control flow op %s", node)
continue

if _is_non_deterministic_op(node):
# If the node is a non-deterministic op, we skip it.
logger.debug("Skipping non-deterministic op %s", node)
continue

node_info = (
node.op_identifier(),
len(node.outputs),
tuple(id(input) for input in node.inputs),
tuple(sorted(attributes.items())),
)
# Check if the node is a common subexpression.
if node_info in existing_node_info_to_the_node:
# If it is, this node has an existing node with the same
# operator, number of outputs, inputs, and attributes.
# We replace the node with the existing node.
modified = True
existing_node = existing_node_info_to_the_node[node_info]
_remove_node_and_replace_values(
graph,
remove_node=node,
remove_values=node.outputs,
new_values=existing_node.outputs,
)
logger.debug("Reusing node %s", existing_node)
else:
# If it is not, add to the mapping.
existing_node_info_to_the_node[node_info] = node
return modified


def _remove_node_and_replace_values(
graph: ir.Graph,
/,
remove_node: ir.Node,
remove_values: Sequence[ir.Value],
new_values: Sequence[ir.Value],
) -> None:
"""Replaces nodes and values in the graph or function.

Args:
graph: The graph to replace nodes and values in.
remove_node: The node to remove.
remove_values: The values to replace.
new_values: The values to replace with.
"""
# Reconnect the users of the deleted values to use the new values
ir.convenience.replace_all_uses_with(remove_values, new_values)
# Update graph/function outputs if the node generates output
if any(remove_value.is_graph_output() for remove_value in remove_values):
replacement_mapping = dict(zip(remove_values, new_values))
for idx, graph_output in enumerate(graph.outputs):
if graph_output in replacement_mapping:
new_value = replacement_mapping[graph_output]
if new_value.is_graph_output() or new_value.is_graph_input():
# If the new value is also a graph input/output, we need to
# create a Identity node to preserve the remove_value and
# prevent from changing new_value name.
identity_node = ir.node(
"Identity",
inputs=[new_value],
outputs=[
ir.Value(
name=graph_output.name,
type=graph_output.type,
shape=graph_output.shape,
)
],
)
# reuse the name of the graph output
graph.outputs[idx] = identity_node.outputs[0]
graph.insert_before(
remove_node,
identity_node,
)
else:
# if new_value is not graph output, we just
# update it to use old_value name.
new_value.name = graph_output.name
graph.outputs[idx] = new_value

graph.remove(remove_node, safe=True)


def _is_non_deterministic_op(node: ir.Node) -> bool:
non_deterministic_ops = frozenset(
{
"RandomUniform",
"RandomNormal",
"RandomUniformLike",
"RandomNormalLike",
"Multinomial",
}
)
return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)


def _is_onnx_domain(d: str) -> bool:
"""Check if the domain is the ONNX domain."""
return d == ""
Loading