|
4 | 4 |
|
5 | 5 | import onnx
|
6 | 6 |
|
7 |
| -from onnxscript import ir |
8 |
| -from onnxscript.optimizer import remove_unused, remove_unused_function |
9 | 7 | from onnxscript.rewriter import function_rule, pattern
|
| 8 | +from onnxscript.rewriter import rewrite as _rewrite |
10 | 9 | from onnxscript.rewriter.onnxruntime import (
|
11 | 10 | group_normalization_merge_silu,
|
12 | 11 | instance_to_group_normalization,
|
@@ -44,18 +43,6 @@ def rewrite(
|
44 | 43 | """
|
45 | 44 | function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES
|
46 | 45 | pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES
|
47 |
| - model = ir.serde.deserialize_model(model_proto) |
48 |
| - # TODO(bowenbao): Function rules first, or pattern rules first? |
49 |
| - if function_rules: |
50 |
| - for rule_cls in function_rules: |
51 |
| - count, model = rule_cls().apply_to_model(model) |
52 |
| - if count > 0: |
53 |
| - print(f"Applied {count} of onnxruntime specific function rewrite rules.") |
54 |
| - if pattern_rules: |
55 |
| - count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model) |
56 |
| - print(f"Applied {count} of onnxruntime specific pattern rewrite rules.") |
57 |
| - |
58 |
| - model_proto = ir.serde.serialize_model(model) |
59 |
| - remove_unused.remove_unused_nodes(model_proto) |
60 |
| - model_proto = remove_unused_function.remove_unused_functions(model_proto) |
61 |
| - return model_proto |
| 46 | + return _rewrite( |
| 47 | + model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules |
| 48 | + ) |
0 commit comments