Skip to content

Improve handling of trailing optional inputs in pattern matching #1948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,19 +1040,14 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool:

self._matched[pattern_node] = node

# TODO: Revisit this to handle optional trailing inputs better.
if pattern_node.allow_other_inputs:
if len(node.inputs) < len(pattern_node.inputs):
return self.fail(
f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})"
)
else:
if len(node.inputs) != len(pattern_node.inputs):
return self.fail(
f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}"
)
if len(node.inputs) > len(pattern_node.inputs) and not pattern_node.allow_other_inputs:
return self.fail(
f"Number of inputs ({len(node.inputs)}) is more than expected ({len(pattern_node.inputs)})"
)

for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs):
for arg_value, arg_pattern in itertools.zip_longest(
node.inputs, pattern_node.inputs, fillvalue=None
):
# arg_pattern could be a Var, if it's the original arg.
if arg_pattern is None:
if arg_value is None:
Expand Down
32 changes: 32 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,38 @@
self.assertEqual(model.graph.node(0).op_type, "ReplacedNone")
self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone")

def test_match_trailing_optional_input(self):
def none_pattern(op, optional_input, x):
# match against a call to Original where the first input may or may not be None
return op.Original(x, optional_input)

def replacement(op, optional_input, x):
if optional_input is None:
return op.ReplacedNone(x)
return op.ReplacedNotNone(x)

rule = pattern.RewriteRule(none_pattern, replacement)

@script()
def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
# Pattern should match following call (with optional_input == None)
t1 = op.Original(x, None)
# as well as this one (with optional_input != None)
z = op.Original(x, t1)
# as well as this one (with optional_input == None)
z = op.Original(x)
return z

model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)

count = rule.apply_to_model(model)
self.assertEqual(count, 3)
self.assertEqual(len(model.graph), 3)
self.assertEqual(model.graph.node(0).op_type, "ReplacedNone")
self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone")
self.assertEqual(model.graph.node(2).op_type, "ReplacedNone")


class PatternBuilderTest(unittest.TestCase):
def test_pattern_builder_context(self):
Expand Down
Loading