|
11 | 11 | import onnxruntime as ort
|
12 | 12 |
|
13 | 13 | from onnxscript import ir
|
14 |
| -from onnxscript.rewriter import generic_pattern |
| 14 | +from onnxscript.rewriter import generic_pattern, pattern |
15 | 15 |
|
16 | 16 | FLOAT = onnx.TensorProto.FLOAT
|
17 | 17 |
|
@@ -41,8 +41,11 @@ def validate_mapping(context, x, y, z, **_) -> bool:
|
41 | 41 | del context
|
42 | 42 | return True
|
43 | 43 |
|
44 |
| - rule = generic_pattern.make_pattern_rule( |
45 |
| - match_pattern, apply_pattern, validate_mapping |
| 44 | + rule = pattern.RewriteRule( |
| 45 | + match_pattern, |
| 46 | + apply_pattern, |
| 47 | + validate_mapping, |
| 48 | + generic_pattern.GenericPatternMatcher, |
46 | 49 | )
|
47 | 50 |
|
48 | 51 | class AddAdd(onnx.reference.op_run.OpRun):
|
@@ -118,8 +121,12 @@ def apply_pattern(op, x, y, w, z, **_):
|
118 | 121 | def validate_mapping(context, **_) -> bool:
|
119 | 122 | return True
|
120 | 123 |
|
121 |
| - rule = generic_pattern.make_pattern_rule( |
122 |
| - match_pattern, apply_pattern, validate_mapping, verbose=10 |
| 124 | + rule = pattern.RewriteRule( |
| 125 | + match_pattern, |
| 126 | + apply_pattern, |
| 127 | + validate_mapping, |
| 128 | + generic_pattern.GenericPatternMatcher, |
| 129 | + verbose=10, |
123 | 130 | )
|
124 | 131 |
|
125 | 132 | class AddAddAddAdd(onnx.reference.op_run.OpRun):
|
@@ -284,8 +291,12 @@ def apply_pattern(op, x, pos_ids, axis, **_):
|
284 | 291 | outputs=2,
|
285 | 292 | )
|
286 | 293 |
|
287 |
| - rule = generic_pattern.make_pattern_rule( |
288 |
| - match_pattern, apply_pattern, validate_mapping, verbose=10 |
| 294 | + rule = pattern.RewriteRule( |
| 295 | + match_pattern, |
| 296 | + apply_pattern, |
| 297 | + validate_mapping, |
| 298 | + generic_pattern.GenericPatternMatcher, |
| 299 | + verbose=10, |
289 | 300 | )
|
290 | 301 |
|
291 | 302 | model = self.get_rotary_model()
|
@@ -345,10 +356,11 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_):
|
345 | 356 | )
|
346 | 357 | return part1, part2
|
347 | 358 |
|
348 |
| - rule = generic_pattern.make_pattern_rule( |
| 359 | + rule = pattern.RewriteRule( |
349 | 360 | rotary_match_pattern,
|
350 | 361 | rotary_apply_pattern,
|
351 | 362 | validate_rotary_mapping,
|
| 363 | + generic_pattern.GenericPatternMatcher, |
352 | 364 | verbose=10,
|
353 | 365 | )
|
354 | 366 |
|
@@ -416,10 +428,11 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
|
416 | 428 | model = onnx.shape_inference.infer_shapes(model)
|
417 | 429 | ir_model = ir.serde.deserialize_model(model)
|
418 | 430 |
|
419 |
| - rule = generic_pattern.make_pattern_rule( |
| 431 | + rule = pattern.RewriteRule( |
420 | 432 | rotary_match_pattern,
|
421 | 433 | rotary_apply_pattern,
|
422 | 434 | validate_rotary_mapping,
|
| 435 | + generic_pattern.GenericPatternMatcher, |
423 | 436 | verbose=10,
|
424 | 437 | )
|
425 | 438 |
|
@@ -472,10 +485,11 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_):
|
472 | 485 | composed_perm = transpose_transpose_mapping(perm0, perm1)
|
473 | 486 | return op.Transpose(X, perm=composed_perm)
|
474 | 487 |
|
475 |
| - rule = generic_pattern.make_pattern_rule( |
| 488 | + rule = pattern.RewriteRule( |
476 | 489 | transpose_transpose_pattern,
|
477 | 490 | transpose_transpose_apply_pattern,
|
478 | 491 | transpose_transpose_check,
|
| 492 | + generic_pattern.GenericPatternMatcher, |
479 | 493 | verbose=0,
|
480 | 494 | )
|
481 | 495 |
|
|
0 commit comments