Skip to content

Commit d620466

Browse files
authored
Refactors onnxscript.rewriter.onnxruntime.rewrite to call onnx.rewriter.rewrite (#1628)
Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
1 parent dc31a6e commit d620466

File tree

5 files changed

+14
-23
lines changed

5 files changed

+14
-23
lines changed

docs/api/tools.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
```
88

99
```{eval-rst}
10-
.. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_config
10+
.. autofunction:: onnxscript.tools.transformers_models.phi.get_phi_model_from_config
1111
```
1212

1313
```{eval-rst}
14-
.. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_config
14+
.. autofunction:: onnxscript.tools.transformers_models.llama.get_llama_model_from_config
1515
```

onnxscript/_legacy_ir/visitor.py

+2
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,8 @@ def get_constant_value(i: int) -> onnx.TensorProto | None:
590590
)
591591

592592
for output in node.output:
593+
if output == "":
594+
continue
593595
info = self.lookup_or_create(output)
594596
if output in output_types:
595597
if info.type is not None:

onnxscript/optimizer/remove_unused_ir.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ def is_used_output(i: int) -> bool:
3232

3333
if is_used_output(1) or is_used_output(2):
3434
return
35-
node.outputs[1].name = ""
36-
node.outputs[2].name = ""
35+
if len(node.outputs) > 1:
36+
node.outputs[1].name = ""
37+
if len(node.outputs) > 2:
38+
node.outputs[2].name = ""
3739
node.attributes.pop("training_mode", None)
3840
return
3941

onnxscript/rewriter/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def rewrite(
4040
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
4141
count = pattern_rewrite_rules.apply_to_model(model_ir)
4242
print(f"Applied {count} of general pattern rewrite rules.")
43+
remove_unused.remove_unused_nodes(model_ir)
44+
model_ir = remove_unused_function.remove_unused_functions(model_ir)
4345
model = ir.serde.serialize_model(model_ir)
44-
remove_unused.remove_unused_nodes(model)
45-
model = remove_unused_function.remove_unused_functions(model)
4646
return model

onnxscript/rewriter/onnxruntime/__init__.py

+4-17
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44

55
import onnx
66

7-
from onnxscript import ir
8-
from onnxscript.optimizer import remove_unused, remove_unused_function
97
from onnxscript.rewriter import function_rule, pattern
8+
from onnxscript.rewriter import rewrite as _rewrite
109
from onnxscript.rewriter.onnxruntime import (
1110
group_normalization_merge_silu,
1211
instance_to_group_normalization,
@@ -44,18 +43,6 @@ def rewrite(
4443
"""
4544
function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES
4645
pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES
47-
model = ir.serde.deserialize_model(model_proto)
48-
# TODO(bowenbao): Function rules first, or pattern rules first?
49-
if function_rules:
50-
for rule_cls in function_rules:
51-
count, model = rule_cls().apply_to_model(model)
52-
if count > 0:
53-
print(f"Applied {count} of onnxruntime specific function rewrite rules.")
54-
if pattern_rules:
55-
count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model)
56-
print(f"Applied {count} of onnxruntime specific pattern rewrite rules.")
57-
58-
model_proto = ir.serde.serialize_model(model)
59-
remove_unused.remove_unused_nodes(model_proto)
60-
model_proto = remove_unused_function.remove_unused_functions(model_proto)
61-
return model_proto
46+
return _rewrite(
47+
model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules
48+
)

0 commit comments

Comments
 (0)