-
Notifications
You must be signed in to change notification settings - Fork 9
[pass] Add CSE #36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[pass] Add CSE #36
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
313c756
add cse
titaiwangms 3b11aac
lint and add test dependency
titaiwangms f5f086b
add onnxscript
titaiwangms e673d5b
add --no-deps
titaiwangms a047bef
fix no-deps
titaiwangms b25a41b
fix test import
titaiwangms e0917d8
Update noxfile.py
justinchuby f9a1c1f
Update src/onnx_ir/passes/common/common_subexpression_elimination_tes…
titaiwangms 0de5ce2
Update src/onnx_ir/passes/common/common_subexpression_elimination.py
titaiwangms 6a5746e
skip non deterministic ops
titaiwangms 61fb5f2
Rename reuse job to make required (#37)
justinchuby 9dc05ce
Configure the do-not-merge bot (#40)
justinchuby 5bd0fa0
Remove RefAttr from docs because it is not a class (#35)
justinchuby 19cdd00
Revert changes to LiftConstantsToInitializersPass (#41)
justinchuby 383bcba
Bump version number to v0.1.1 (#42)
justinchuby 8fe5d5d
Bump ruff from 0.11.11 to 0.11.12 in /requirements/lintrunner (#47)
dependabot[bot] 11beaee
Bump mypy from 1.15.0 to 1.16.0 in /requirements/lintrunner (#46)
dependabot[bot] 750b691
Bump ossf/scorecard-action from 2.4.1 to 2.4.2 (#48)
dependabot[bot] 3cef183
Prevent values produced by other nodes to be added as graph inputs (#51)
justinchuby 53468eb
Create release.yml for generating release notes (#52)
justinchuby 220d1d8
Create `add` methods on initializers and attributes (#33)
justinchuby 174742c
Update codecov.yml to ignore unit tests (#53)
justinchuby 1823b62
Merge branch 'main' into titaiwang/add_cse
justinchuby 3002e3b
Update src/onnx_ir/passes/common/common_subexpression_elimination.py
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
177 changes: 177 additions & 0 deletions
177
src/onnx_ir/passes/common/common_subexpression_elimination.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# Copyright (c) ONNX Project Contributors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""Eliminate common subexpression in ONNX graphs.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"CommonSubexpressionEliminationPass", | ||
] | ||
|
||
import logging | ||
|
||
from collections.abc import Sequence | ||
|
||
import onnx_ir as ir | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CommonSubexpressionEliminationPass(ir.passes.InPlacePass): | ||
"""Eliminate common subexpression in ONNX graphs.""" | ||
|
||
def call(self, model: ir.Model) -> ir.passes.PassResult: | ||
"""Return the same ir.Model but with CSE applied to the graph.""" | ||
modified = False | ||
graph = model.graph | ||
|
||
modified = _eliminate_common_subexpression(graph, modified) | ||
|
||
return ir.passes.PassResult( | ||
model, | ||
modified=modified, | ||
) | ||
|
||
|
||
def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool: | ||
"""Eliminate common subexpression in ONNX graphs.""" | ||
|
||
# node to node identifier, length of outputs, inputs, and attributes | ||
existing_node_info_to_the_node: dict[ | ||
tuple[ | ||
ir.OperatorIdentifier, | ||
int, # len(outputs) | ||
tuple[int, ...], # input ids | ||
tuple[tuple[str, object], ...], # attributes | ||
], | ||
ir.Node, | ||
] = {} | ||
|
||
for node in graph: | ||
# Skip control flow ops like Loop and If. | ||
control_flow_op: bool = False | ||
# Use equality to check if the node is a common subexpression. | ||
attributes = {} | ||
for k, v in node.attributes.items(): | ||
# TODO(exporter team): CSE subgraphs. | ||
# NOTE: control flow ops like Loop and If won't be CSEd | ||
# because attribute: graph won't match. | ||
if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS): | ||
control_flow_op = True | ||
logger.debug("Skipping control flow op %s", node) | ||
# The attribute value could be directly taken from the original | ||
# protobuf, so we need to make a copy of it. | ||
value = v.value | ||
if v.type in ( | ||
ir.AttributeType.INTS, | ||
ir.AttributeType.FLOATS, | ||
ir.AttributeType.STRINGS, | ||
): | ||
# For INT, FLOAT and STRING attributes, we convert them to tuples | ||
# to ensure they are hashable. | ||
value = tuple(value) | ||
attributes[k] = value | ||
|
||
if control_flow_op: | ||
# If the node is a control flow op, we skip it. | ||
logger.debug("Skipping control flow op %s", node) | ||
continue | ||
|
||
if _is_non_deterministic_op(node): | ||
# If the node is a non-deterministic op, we skip it. | ||
logger.debug("Skipping non-deterministic op %s", node) | ||
continue | ||
|
||
node_info = ( | ||
node.op_identifier(), | ||
len(node.outputs), | ||
tuple(id(input) for input in node.inputs), | ||
tuple(sorted(attributes.items())), | ||
) | ||
# Check if the node is a common subexpression. | ||
if node_info in existing_node_info_to_the_node: | ||
# If it is, this node has an existing node with the same | ||
# operator, number of outputs, inputs, and attributes. | ||
# We replace the node with the existing node. | ||
modified = True | ||
existing_node = existing_node_info_to_the_node[node_info] | ||
_remove_node_and_replace_values( | ||
graph, | ||
remove_node=node, | ||
remove_values=node.outputs, | ||
new_values=existing_node.outputs, | ||
) | ||
logger.debug("Reusing node %s", existing_node) | ||
else: | ||
# If it is not, add to the mapping. | ||
existing_node_info_to_the_node[node_info] = node | ||
return modified | ||
|
||
|
||
def _remove_node_and_replace_values( | ||
graph: ir.Graph, | ||
/, | ||
remove_node: ir.Node, | ||
remove_values: Sequence[ir.Value], | ||
new_values: Sequence[ir.Value], | ||
) -> None: | ||
"""Replaces nodes and values in the graph or function. | ||
|
||
Args: | ||
graph: The graph to replace nodes and values in. | ||
remove_node: The node to remove. | ||
remove_values: The values to replace. | ||
new_values: The values to replace with. | ||
""" | ||
# Reconnect the users of the deleted values to use the new values | ||
ir.convenience.replace_all_uses_with(remove_values, new_values) | ||
# Update graph/function outputs if the node generates output | ||
if any(remove_value.is_graph_output() for remove_value in remove_values): | ||
replacement_mapping = dict(zip(remove_values, new_values)) | ||
for idx, graph_output in enumerate(graph.outputs): | ||
if graph_output in replacement_mapping: | ||
new_value = replacement_mapping[graph_output] | ||
if new_value.is_graph_output() or new_value.is_graph_input(): | ||
# If the new value is also a graph input/output, we need to | ||
# create a Identity node to preserve the remove_value and | ||
# prevent from changing new_value name. | ||
identity_node = ir.node( | ||
"Identity", | ||
inputs=[new_value], | ||
outputs=[ | ||
ir.Value( | ||
name=graph_output.name, | ||
type=graph_output.type, | ||
shape=graph_output.shape, | ||
) | ||
], | ||
) | ||
# reuse the name of the graph output | ||
graph.outputs[idx] = identity_node.outputs[0] | ||
graph.insert_before( | ||
remove_node, | ||
identity_node, | ||
) | ||
else: | ||
# if new_value is not graph output, we just | ||
# update it to use old_value name. | ||
new_value.name = graph_output.name | ||
graph.outputs[idx] = new_value | ||
|
||
graph.remove(remove_node, safe=True) | ||
|
||
|
||
def _is_non_deterministic_op(node: ir.Node) -> bool: | ||
non_deterministic_ops = frozenset( | ||
{ | ||
"RandomUniform", | ||
"RandomNormal", | ||
"RandomUniformLike", | ||
"RandomNormalLike", | ||
"Multinomial", | ||
} | ||
) | ||
return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain) | ||
|
||
|
||
def _is_onnx_domain(d: str) -> bool: | ||
"""Check if the domain is the ONNX domain.""" | ||
return d == "" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.