
Commit 1dffe60

kunal-vaishnavi authored and jatinwadhwa921 committed
Add fusions for SigLIP and Conformer-Encoder (microsoft#23528)
### Description

This PR adds fusions for [Google's SigLIP model](https://huggingface.co/google/siglip-base-patch16-224/) and Microsoft's internal conformer-encoder model.

Here is an example of how to run the ORT transformer optimizer for the SigLIP model.

```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1152 --use_external_data_format --opt_level 0 --disable_shape_inference
```

Here is an example of how to run the ORT transformer optimizer for the conformer-encoder model.

```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type conformer --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0 --disable_shape_inference --convert_attribute
```

### Motivation and Context

This PR helps optimize multi-modal models that use SigLIP for the vision encoder and conformer-encoder for the speech encoder.

This PR uses changes from the following PRs:
- pytorch/pytorch#144801
- microsoft/onnxscript#2018
- microsoft/onnxscript#2019
- microsoft/onnxscript#2020
- microsoft/onnxscript#2021
- microsoft/onnxscript#2022
- microsoft/onnxscript#2024
- microsoft/onnxscript#2025
- microsoft/onnxscript#2029
- microsoft/onnxscript#2033

### Introduction of ONNX Script

This PR introduces [ONNX Script](https://github.com/microsoft/onnxscript) into the ORT transformer optimizer as an optional step via the `fold_transpose_initializers()` method of the `DynamoOnnxHelper` class.
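For readers who want to call the new ONNX Script step directly rather than through `optimizer.py`, here is a minimal sketch (not part of this commit; the paths are placeholders), assuming `onnxscript` is installed:

```python
from onnxscript import ir

from dynamo_onnx_helper import DynamoOnnxHelper

# Load the exported model into the ONNX Script IR.
ir_model = ir.load("/path/to/model.onnx")

# Constant-fold Transpose ops that consume initializers while keeping
# the original initializer names, so later fusion passes still match.
DynamoOnnxHelper.fold_transpose_initializers(ir_model)

ir.save(ir_model, "/path/to/model_folded.onnx")
```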
1 parent e5cafd9 commit 1dffe60

19 files changed: +677 −154 lines

onnxruntime/python/tools/transformers/dynamo_onnx_helper.py

Lines changed: 117 additions & 16 deletions
```diff
@@ -2,54 +2,61 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
-import logging
+from collections.abc import Sequence
+from logging import getLogger
+from typing import Any
 
+import numpy as np
 import onnx
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
 
 
 class DynamoOnnxHelper:
     """
-    Helper class for processing ONNX models exported by torch Dynamo.
+    Helper class for processing ONNX models exported by Torch Dynamo.
     """
 
     def __init__(self, model: onnx.ModelProto):
-        self.model = model
+        self.model = OnnxModel(model)
 
     def update_edges(self, edge_mapping: dict) -> None:
         """
         Updates the edges in the model according to the given mapping.
         """
-        for node in self.model.graph.node:
+        for node in self.model.model.graph.node:
             for i in range(len(node.input)):
                 if node.input[i] in edge_mapping:
                     node.input[i] = edge_mapping[node.input[i]]
             for i in range(len(node.output)):
                 if node.output[i] in edge_mapping:
                     node.output[i] = edge_mapping[node.output[i]]
 
-        for graph_input in self.model.graph.input:
+        for graph_input in self.model.model.graph.input:
             if graph_input.name in edge_mapping:
                 graph_input.name = edge_mapping[graph_input.name]
-        for graph_output in self.model.graph.output:
+        for graph_output in self.model.model.graph.output:
             if graph_output.name in edge_mapping:
                 graph_output.name = edge_mapping[graph_output.name]
 
     def unroll_function(self, func_name: str) -> None:
         """
         Unrolls the function with the given name in the model.
         """
-        logging.info(f"Unrolling function {func_name}...")
+        logger.debug(f"Unrolling function {func_name}...")
         nodes_to_remove = []
         nodes_to_add = []
         edges_to_remove = []
         edges_to_add = []
-        for node in self.model.graph.node:
+        for node in self.model.model.graph.node:
             if node.op_type == func_name:
                 nodes_to_remove.append(node)
                 edges_to_remove.extend(list(node.input) + list(node.output))
 
         func_to_remove = None
-        for f in self.model.functions:
+        for f in self.model.model.functions:
             if f.name == func_name:
                 nodes_to_add.extend(list(f.node))
                 edges_to_add.extend(list(f.input) + list(f.output))
@@ -58,11 +65,11 @@ def unroll_function(self, func_name: str) -> None:
         assert len(edges_to_remove) == len(edges_to_add)
 
         for node in nodes_to_remove:
-            self.model.graph.node.remove(node)
+            self.model.model.graph.node.remove(node)
         for node in nodes_to_add:
-            self.model.graph.node.append(node)
+            self.model.model.graph.node.append(node)
         if func_to_remove is not None:
-            self.model.functions.remove(func_to_remove)
+            self.model.model.functions.remove(func_to_remove)
 
         edge_mapping = {}
         for i in range(len(edges_to_remove)):
@@ -79,26 +86,120 @@ def remove_function(self, func_name: str, input_id: int, output_id: int) -> None
         """
         edge_mapping = {}
         nodes_to_remove = []
-        for node in self.model.graph.node:
+        for node in self.model.model.graph.node:
             if node.op_type.find(func_name) != -1:
                 edge_mapping[node.input[input_id]] = node.output[output_id]
                 nodes_to_remove.append(node)
         for node in nodes_to_remove:
-            self.model.graph.node.remove(node)
+            self.model.model.graph.node.remove(node)
 
         self.update_edges(edge_mapping)
 
     def remove_dropout_layer(self) -> None:
         """
         Removes the dropout layer in the model.
         """
-        logging.info("Removing dropout layer...")
+        logger.debug("Removing dropout layer...")
         self.remove_function("Dropout", 0, 0)
 
     def remove_lm_head_layer(self) -> None:
         """
         Removes the LM head layer in the model.
         """
-        logging.info("Removing LM head layer...")
+        logger.debug("Removing LM head layer...")
         # bugbug: need to copy the right vi over
         self.remove_function("Linear_lm_head", 2, 0)
+
+    def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
+        if raw:
+            np_type = helper.tensor_dtype_to_np_dtype(data_type)
+            if not isinstance(vals, np.ndarray):
+                bytes = np.array(vals, dtype=np_type).tobytes()
+            else:
+                bytes = vals.astype(np_type).tobytes()
+            tensor = helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=bytes,
+                raw=True,
+            )
+        else:
+            tensor = helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=vals,
+                raw=False,
+            )
+
+        self.model.add_initializer(tensor)
+        return tensor
+
+    def convert_constants_to_initializers(self, min_size: int = 1) -> None:
+        """
+        Converts Constant ops of size [min_size] or higher to initializers
+        """
+        logger.debug(f"Converting constants greater than size {min_size} to initializers")
+
+        constant_nodes = self.model.get_nodes_by_op_type("Constant")
+        nodes_to_remove = []
+
+        for node in constant_nodes:
+            # Get info from Constant op
+            np_data = self.model.get_constant_value(node.output[0])
+
+            # Skip if there are less than [min_size] elements
+            if np_data is None or np_data.size < min_size:
+                continue
+
+            # Add new initializer with same name as Constant op's output
+            for att in node.attribute:
+                if att.name == "value":
+                    self.add_initializer(
+                        name=node.output[0],
+                        data_type=att.t.data_type,
+                        dims=list(np_data.shape),
+                        vals=np_data,
+                    )
+                    break
+
+            nodes_to_remove.append(node)
+
+        # Remove Constant ops from graph
+        self.model.remove_nodes(nodes_to_remove)
+
+    def clear_metadata(self) -> None:
+        """
+        Clear metadata fields in all nodes
+        """
+        for graph in self.model.graphs():
+            graph.ClearField("metadata_props")
+        for node in self.model.nodes():
+            node.ClearField("metadata_props")
+
+    @staticmethod
+    def fold_transpose_initializers(model) -> None:
+        """
+        Constant fold Transpose initializers without changing the initializer names
+        """
+        from onnxscript import ir
+
+        for name, initializer in model.graph.initializers.items():
+            user_nodes = initializer.consumers()
+            if len(user_nodes) == 1 and user_nodes[0].op_type == "Transpose":
+                transpose_node = user_nodes[0]
+                perm = transpose_node.attributes.get("perm")
+                if perm is None:
+                    transposed_tensor = ir.tensor(initializer.const_value.numpy().transpose())
+                else:
+                    transposed_tensor = ir.tensor(initializer.const_value.numpy().transpose(perm.as_ints()))
+                new_initializer = ir.Value(
+                    name=initializer.name,
+                    shape=transposed_tensor.shape,
+                    type=ir.TensorType(transposed_tensor.dtype),
+                    const_value=transposed_tensor,
+                )
+                ir.convenience.replace_all_uses_with(transpose_node.outputs[0], new_initializer)
+                model.graph.initializers[name] = new_initializer
+                transpose_node.graph.remove(transpose_node, safe=True)
```
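A hypothetical usage sketch (not part of this diff; the model paths are placeholders) showing how the other new helper methods might be combined after a Dynamo export:

```python
import onnx

from dynamo_onnx_helper import DynamoOnnxHelper

dynamo_helper = DynamoOnnxHelper(onnx.load("/path/to/model.onnx"))

# Promote Constant ops with at least 4 elements to initializers so that
# downstream fusion passes can treat them like ordinary weights.
dynamo_helper.convert_constants_to_initializers(min_size=4)

# Drop exporter-added metadata_props from the graph and its nodes.
dynamo_helper.clear_metadata()

# __init__ wraps the proto in OnnxModel, so the raw ModelProto is one level down.
onnx.save(dynamo_helper.model.model, "/path/to/model_clean.onnx")
```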

onnxruntime/python/tools/transformers/fusion_attention.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -838,12 +838,10 @@ def create_attention_node(
             attention_inputs.append(past_kv)
 
         if add_qk_str:
-            mask_output_name = self.reshape_add_qk(add_qk_str)
-
-            # Add attention mask to attention node
+            # Add additional add to attention node (input name = attention_bias)
             if not past_exists:
                 attention_inputs.append("")
-            attention_inputs.append(mask_output_name)
+            attention_inputs.append(add_qk_str)
 
         attention_outputs = [output]
         if present_k and present_v:
```
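For context, the change above feeds the add-QK tensor straight into the `attention_bias` input slot of the com.microsoft `Attention` contrib op, which sits after the optional `past` input, instead of reshaping it into a mask first. A hedged sketch (illustrative names, not code from this PR) of a node wired that way:

```python
from onnx import helper

# Illustrative only: an Attention node whose sixth input receives the
# add-QK tensor directly (the slot ONNX Runtime calls attention_bias).
attention_node = helper.make_node(
    "Attention",
    inputs=[
        "hidden_states",  # input
        "qkv_weight",     # packed Q/K/V weight
        "qkv_bias",       # packed Q/K/V bias
        "",               # mask_index (unused here)
        "",               # past (unused here)
        "add_qk",         # attention_bias, added to the QK^T scores
    ],
    outputs=["attention_output"],
    name="Attention_0",
    domain="com.microsoft",
    num_heads=16,
)
```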
