Skip to content

Commit a5ed079

Browse files
Add verbose support in single-output pattern-matcher (#1555)
Next step in unifying the two pattern-matchers: * Refactor the pattern-matching algorithm out of the pattern-IR classes * Add support for verbose-flag: will print info about status during algorithm * Unify the constructors for rewrite-rule --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 69ae7f4 commit a5ed079

File tree

5 files changed

+243
-122
lines changed

5 files changed

+243
-122
lines changed

examples/pattern_rewriting.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import onnx.numpy_helper as onh
1515

1616
from onnxscript import ir
17-
from onnxscript.rewriter import generic_pattern
17+
from onnxscript.rewriter import pattern
1818

1919

2020
def get_rotary_model(bad_model=False):
@@ -99,9 +99,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
9999
#
100100
# The rule is easy to create.
101101

102-
rule = generic_pattern.make_pattern_rule(
103-
rotary_match_pattern, rotary_apply_pattern, verbose=10
104-
)
102+
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)
105103

106104
##########################
107105
# Let's apply it.
@@ -136,9 +134,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
136134
# The match did not happen.
137135
# Let's increase the verbosity.
138136

139-
rule = generic_pattern.make_pattern_rule(
140-
rotary_match_pattern, rotary_apply_pattern, verbose=10
141-
)
137+
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)
142138

143139
rule.apply_to_model(ir_model)
144140

onnxscript/rewriter/generic_pattern.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import os
66
import textwrap
7+
import warnings
78
from typing import Any, Callable, Iterator, Sequence
89

910
import onnxscript.rewriter.pattern as orp
@@ -79,7 +80,7 @@ def _to_match_result(pmr: PatternMatchResult) -> orp.MatchResult:
7980
8081
TODO: This is a temporary hack until MatchResult and PatternMatchResult are unified.
8182
"""
82-
result = orp.MatchResult(success=True)
83+
result = orp.MatchResult()
8384
result.nodes.extend(pmr.model_nodes)
8485
for var, val in pmr.matched_pattern_to_model_value.items():
8586
if var.name is not None:
@@ -633,6 +634,11 @@ def make_pattern_rule(
633634
the rewriting rule
634635
"""
635636

637+
warnings.warn(
638+
"make_pattern_rule(...) is deprecated, use pattern.RewriteRule(...) instead",
639+
FutureWarning,
640+
stacklevel=2,
641+
)
636642
pattern = orp._to_graph_pattern(match_pattern_function)
637643
matcher = GenericPatternMatcher(pattern)
638644
return orp.RewriteRule(

onnxscript/rewriter/generic_pattern_test.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import onnxruntime as ort
1212

1313
from onnxscript import ir
14-
from onnxscript.rewriter import generic_pattern
14+
from onnxscript.rewriter import generic_pattern, pattern
1515

1616
FLOAT = onnx.TensorProto.FLOAT
1717

@@ -41,8 +41,11 @@ def validate_mapping(context, x, y, z, **_) -> bool:
4141
del context
4242
return True
4343

44-
rule = generic_pattern.make_pattern_rule(
45-
match_pattern, apply_pattern, validate_mapping
44+
rule = pattern.RewriteRule(
45+
match_pattern,
46+
apply_pattern,
47+
validate_mapping,
48+
generic_pattern.GenericPatternMatcher,
4649
)
4750

4851
class AddAdd(onnx.reference.op_run.OpRun):
@@ -118,8 +121,12 @@ def apply_pattern(op, x, y, w, z, **_):
118121
def validate_mapping(context, **_) -> bool:
119122
return True
120123

121-
rule = generic_pattern.make_pattern_rule(
122-
match_pattern, apply_pattern, validate_mapping, verbose=10
124+
rule = pattern.RewriteRule(
125+
match_pattern,
126+
apply_pattern,
127+
validate_mapping,
128+
generic_pattern.GenericPatternMatcher,
129+
verbose=10,
123130
)
124131

125132
class AddAddAddAdd(onnx.reference.op_run.OpRun):
@@ -284,8 +291,12 @@ def apply_pattern(op, x, pos_ids, axis, **_):
284291
outputs=2,
285292
)
286293

287-
rule = generic_pattern.make_pattern_rule(
288-
match_pattern, apply_pattern, validate_mapping, verbose=10
294+
rule = pattern.RewriteRule(
295+
match_pattern,
296+
apply_pattern,
297+
validate_mapping,
298+
generic_pattern.GenericPatternMatcher,
299+
verbose=10,
289300
)
290301

291302
model = self.get_rotary_model()
@@ -345,10 +356,11 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
345356
)
346357
return part1, part2
347358

348-
rule = generic_pattern.make_pattern_rule(
359+
rule = pattern.RewriteRule(
349360
rotary_match_pattern,
350361
rotary_apply_pattern,
351362
validate_rotary_mapping,
363+
generic_pattern.GenericPatternMatcher,
352364
verbose=10,
353365
)
354366

@@ -416,10 +428,11 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
416428
model = onnx.shape_inference.infer_shapes(model)
417429
ir_model = ir.serde.deserialize_model(model)
418430

419-
rule = generic_pattern.make_pattern_rule(
431+
rule = pattern.RewriteRule(
420432
rotary_match_pattern,
421433
rotary_apply_pattern,
422434
validate_rotary_mapping,
435+
generic_pattern.GenericPatternMatcher,
423436
verbose=10,
424437
)
425438

@@ -472,10 +485,11 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_):
472485
composed_perm = transpose_transpose_mapping(perm0, perm1)
473486
return op.Transpose(X, perm=composed_perm)
474487

475-
rule = generic_pattern.make_pattern_rule(
488+
rule = pattern.RewriteRule(
476489
transpose_transpose_pattern,
477490
transpose_transpose_apply_pattern,
478491
transpose_transpose_check,
492+
generic_pattern.GenericPatternMatcher,
479493
verbose=0,
480494
)
481495

0 commit comments

Comments
 (0)