Skip to content

Commit

Permalink
Add verbose support in single-output pattern-matcher (#1555)
Browse files Browse the repository at this point in the history
Next step in unifying the two pattern-matchers:

* Refactor the pattern-matching algorithm out of the pattern-IR classes
* Add support for verbose-flag: will print info about status during
algorithm
* Unify the constructors for rewrite-rule

---------

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
  • Loading branch information
gramalingam and justinchuby authored May 20, 2024
1 parent 69ae7f4 commit a5ed079
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 122 deletions.
10 changes: 3 additions & 7 deletions examples/pattern_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import onnx.numpy_helper as onh

from onnxscript import ir
from onnxscript.rewriter import generic_pattern
from onnxscript.rewriter import pattern


def get_rotary_model(bad_model=False):
Expand Down Expand Up @@ -99,9 +99,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
#
# The rule is easy to create.

rule = generic_pattern.make_pattern_rule(
rotary_match_pattern, rotary_apply_pattern, verbose=10
)
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

##########################
# Let's apply it.
Expand Down Expand Up @@ -136,9 +134,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
# The match did not happen.
# Let's increase the verbosity.

rule = generic_pattern.make_pattern_rule(
rotary_match_pattern, rotary_apply_pattern, verbose=10
)
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

rule.apply_to_model(ir_model)

Expand Down
8 changes: 7 additions & 1 deletion onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import os
import textwrap
import warnings
from typing import Any, Callable, Iterator, Sequence

import onnxscript.rewriter.pattern as orp
Expand Down Expand Up @@ -79,7 +80,7 @@ def _to_match_result(pmr: PatternMatchResult) -> orp.MatchResult:
TODO: This is a temporary hack until MatchResult and PatternMatchResult are unified.
"""
result = orp.MatchResult(success=True)
result = orp.MatchResult()
result.nodes.extend(pmr.model_nodes)
for var, val in pmr.matched_pattern_to_model_value.items():
if var.name is not None:
Expand Down Expand Up @@ -633,6 +634,11 @@ def make_pattern_rule(
the rewriting rule
"""

warnings.warn(
"make_pattern_rule(...) is deprecated, use pattern.RewriteRule(...) instead",
FutureWarning,
stacklevel=2,
)
pattern = orp._to_graph_pattern(match_pattern_function)
matcher = GenericPatternMatcher(pattern)
return orp.RewriteRule(
Expand Down
34 changes: 24 additions & 10 deletions onnxscript/rewriter/generic_pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import onnxruntime as ort

from onnxscript import ir
from onnxscript.rewriter import generic_pattern
from onnxscript.rewriter import generic_pattern, pattern

FLOAT = onnx.TensorProto.FLOAT

Expand Down Expand Up @@ -41,8 +41,11 @@ def validate_mapping(context, x, y, z, **_) -> bool:
del context
return True

rule = generic_pattern.make_pattern_rule(
match_pattern, apply_pattern, validate_mapping
rule = pattern.RewriteRule(
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
)

class AddAdd(onnx.reference.op_run.OpRun):
Expand Down Expand Up @@ -118,8 +121,12 @@ def apply_pattern(op, x, y, w, z, **_):
def validate_mapping(context, **_) -> bool:
return True

rule = generic_pattern.make_pattern_rule(
match_pattern, apply_pattern, validate_mapping, verbose=10
rule = pattern.RewriteRule(
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
verbose=10,
)

class AddAddAddAdd(onnx.reference.op_run.OpRun):
Expand Down Expand Up @@ -284,8 +291,12 @@ def apply_pattern(op, x, pos_ids, axis, **_):
outputs=2,
)

rule = generic_pattern.make_pattern_rule(
match_pattern, apply_pattern, validate_mapping, verbose=10
rule = pattern.RewriteRule(
match_pattern,
apply_pattern,
validate_mapping,
generic_pattern.GenericPatternMatcher,
verbose=10,
)

model = self.get_rotary_model()
Expand Down Expand Up @@ -345,10 +356,11 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
)
return part1, part2

rule = generic_pattern.make_pattern_rule(
rule = pattern.RewriteRule(
rotary_match_pattern,
rotary_apply_pattern,
validate_rotary_mapping,
generic_pattern.GenericPatternMatcher,
verbose=10,
)

Expand Down Expand Up @@ -416,10 +428,11 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
model = onnx.shape_inference.infer_shapes(model)
ir_model = ir.serde.deserialize_model(model)

rule = generic_pattern.make_pattern_rule(
rule = pattern.RewriteRule(
rotary_match_pattern,
rotary_apply_pattern,
validate_rotary_mapping,
generic_pattern.GenericPatternMatcher,
verbose=10,
)

Expand Down Expand Up @@ -472,10 +485,11 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_):
composed_perm = transpose_transpose_mapping(perm0, perm1)
return op.Transpose(X, perm=composed_perm)

rule = generic_pattern.make_pattern_rule(
rule = pattern.RewriteRule(
transpose_transpose_pattern,
transpose_transpose_apply_pattern,
transpose_transpose_check,
generic_pattern.GenericPatternMatcher,
verbose=0,
)

Expand Down
Loading

0 comments on commit a5ed079

Please sign in to comment.