Skip to content

Commit d425f81

Browse files
committed
Update on "[ET-VK] Minor unroll tuning to improve conv2d pw perf."
This diff provides a minor unroll tuning to improve the performance of the conv2d pointwise (pw) operation in the Executorch Vulkan backend. Differential Revision: [D75420510](https://our.internmc.facebook.com/intern/diff/D75420510/) [ghstack-poisoned]
2 parents 686d7ae + f27da20 commit d425f81

27 files changed

+975
-198
lines changed

.github/workflows/apple-perf.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,7 @@ jobs:
386386
echo "::endgroup::"
387387
388388
echo "::group::Build ExecuTorch iOS frameworks"
389-
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
390-
scripts/build_apple_frameworks.sh --Release --Debug --coreml --custom --mps --optimized --portable --quantized --xnnpack
389+
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output scripts/build_apple_frameworks.sh
391390
echo "::endgroup::"
392391
393392
# NB: Although exported models can be copied to this directory and bundled together with the

.github/workflows/apple.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,7 @@ jobs:
173173
backends/apple/mps/install_requirements.sh
174174
175175
# Build iOS Frameworks
176-
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
177-
scripts/build_apple_frameworks.sh --Release --Debug --coreml --custom --mps --optimized --portable --quantized --xnnpack
176+
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output scripts/build_apple_frameworks.sh
178177
179178
# Bundle iOS Frameworks
180179
for FRAMEWORK in "${FRAMEWORKS[@]}"; do (
@@ -314,8 +313,7 @@ jobs:
314313
echo "::endgroup::"
315314
316315
echo "::group::Build ExecuTorch iOS frameworks"
317-
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
318-
scripts/build_apple_frameworks.sh --Release --Debug --coreml --custom --mps --optimized --portable --quantized --xnnpack
316+
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output scripts/build_apple_frameworks.sh
319317
echo "::endgroup::"
320318
321319
echo "::group::Build ExecuTorch benchmark app"

backends/qualcomm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D
78
from .annotate_quant_attrs import AnnotateQuantAttrs
89
from .annotate_stack import AnnotateStack
910
from .annotate_unbind import AnnotateUnbind
@@ -16,6 +17,7 @@
1617
from .decompose_einsum import DecomposeEinsum
1718
from .decompose_expm1 import DecomposeExpM1
1819
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
20+
from .decompose_roll import DecomposeRoll
1921
from .decompose_silu import DecomposeSilu
2022
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
2123
from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -39,6 +41,7 @@
3941

4042

4143
__all__ = [
44+
AnnotateAdaptiveAvgPool1D,
4245
AnnotateQuantAttrs,
4346
AnnotateStack,
4447
AnnotateUnbind,
@@ -51,6 +54,7 @@
5154
DecomposeEinsum,
5255
DecomposeExpM1,
5356
DecomposeLinalgVectorNorm,
57+
DecomposeRoll,
5458
DecomposeSilu,
5559
ExpandBroadcastTensorShape,
5660
FixedLinearKeepDim,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
8+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
11+
12+
from .utils import get_quant_attrs
13+
14+
15+
class AnnotateAdaptiveAvgPool1D(ExportPass):
16+
"""
17+
Add "quant_attrs" to graph nodes' meta from the QDQ information
18+
generated after quantization process.
19+
adaptive_avg_pool1d got decomposed to unsqueeze -> adaptive_avg_pool2d -> squeeze
20+
"""
21+
22+
decomp_ops = [torch.ops.aten.adaptive_avg_pool2d.default]
23+
24+
def __init__(self, edge_program: torch.export.ExportedProgram):
25+
super(AnnotateAdaptiveAvgPool1D, self).__init__()
26+
self.edge_program = edge_program
27+
28+
def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
29+
partitions = get_source_partitions(
30+
graph_module.graph, [torch.ops.aten.adaptive_avg_pool1d.default]
31+
)
32+
for src_partitions in partitions.values():
33+
for src_partition in src_partitions:
34+
output = src_partition.output_nodes[0]
35+
if (list(output.users)[0].target) in q_ops:
36+
quant_attrs = get_quant_attrs(
37+
self.edge_program, list(output.users)[0]
38+
)
39+
for n in src_partition.nodes:
40+
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
41+
42+
def call(self, graph_module: torch.fx.GraphModule):
43+
self._annotate_adaptive_avg_pool1d(graph_module)
44+
graph_module.recompile()
45+
return PassResult(graph_module, True)

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict
88

99
import torch
10+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
1011
from executorch.backends.qualcomm.builders.utils import get_parameter
1112
from executorch.backends.qualcomm.utils.constants import (
1213
QCOM_DTYPE,
@@ -20,7 +21,7 @@
2021
)
2122
from executorch.exir.pass_base import ExportPass, PassResult
2223

23-
from .utils import dq_ops, get_quant_attrs, q_ops
24+
from .utils import get_quant_attrs
2425

2526

2627
class AnnotateQuantAttrs(ExportPass):

backends/qualcomm/_passes/annotate_stack.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
78
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
89
from executorch.exir.pass_base import ExportPass, PassResult
910
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1011

11-
from .utils import get_quant_attrs, q_ops
12+
from .utils import get_quant_attrs
1213

1314

1415
class AnnotateStack(ExportPass):

backends/qualcomm/_passes/annotate_unbind.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
78
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
89
from executorch.exir.pass_base import ExportPass, PassResult
910
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1011

11-
from .utils import dq_ops, get_quant_attrs
12+
from .utils import get_quant_attrs
1213

1314

1415
class AnnotateUnbind(ExportPass):
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
from .utils import copy_nn_module_stack
11+
12+
13+
class SliceCopy(torch.nn.Module):
14+
def __init__(self, val_shape, shifts, dims):
15+
super().__init__()
16+
self.val_shape = val_shape
17+
if dims[0] is None:
18+
self.shifts = [shifts[0] % torch.numel(torch.tensor(val_shape))]
19+
else:
20+
self.shifts = [shift % val_shape[dim] for shift, dim in zip(shifts, dims)]
21+
self.dims = dims
22+
23+
def forward(self, x):
24+
if self.dims[0] is None:
25+
y = x.flatten()
26+
y = torch.cat((y[-self.shifts[0] :], y[: -self.shifts[0]]))
27+
return y.view(self.val_shape)
28+
29+
for shift, dim in zip(self.shifts, self.dims):
30+
x = torch.cat(
31+
(
32+
x[(slice(None),) * dim + (slice(-shift, None),)],
33+
x[(slice(None),) * dim + (slice(0, -shift),)],
34+
),
35+
dim=dim,
36+
)
37+
return x
38+
39+
40+
class DecomposeRoll(ExportPass):
41+
"""
42+
Decompose roll into slice and cat.
43+
"""
44+
45+
def __init__(self) -> None:
46+
super().__init__()
47+
48+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
49+
graph = graph_module.graph
50+
for node in graph.nodes:
51+
if "roll" in str(node.target):
52+
input_node, shifts = node.args[0], node.args[1]
53+
dims = node.args[2] if len(node.args) == 3 else None
54+
55+
# Normalize shifts and dims to lists
56+
shifts = shifts if isinstance(shifts, (list, tuple)) else [shifts]
57+
dims = dims if isinstance(dims, (list, tuple)) else [dims]
58+
59+
model = SliceCopy(input_node.meta["val"].shape, shifts, dims)
60+
decomposed_module = torch.export.export(
61+
model, (input_node.meta["val"],), strict=True
62+
).module()
63+
64+
with graph.inserting_before(node):
65+
# remap is used to map original node values to new node values,
66+
# which ensures that reference to nodes are correctly updated in the new graph
67+
remap = {"x": input_node}
68+
69+
for decomposed_node in decomposed_module.graph.nodes:
70+
copy_nn_module_stack(node, decomposed_node)
71+
# no need to copy existent 'output'
72+
if decomposed_node.op == "output":
73+
for user in node.users.copy():
74+
# remap
75+
user.replace_input_with(
76+
node,
77+
remap[decomposed_node.args[0][0]],
78+
)
79+
# no need to copy existent placeholders
80+
elif decomposed_node.op == "placeholder":
81+
# replace node map from string to graph node
82+
remap[decomposed_node] = remap.pop(decomposed_node.name)
83+
else:
84+
remap[decomposed_node] = graph.node_copy(
85+
decomposed_node,
86+
arg_transform=lambda x, remap=remap: remap[x],
87+
)
88+
89+
graph.erase_node(node)
90+
91+
graph.eliminate_dead_code()
92+
graph_module.recompile()
93+
return PassResult(graph_module, True)

backends/qualcomm/_passes/expand_broadcast_tensor_shape.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
89
from executorch.exir.dialects._ops import ops as exir_ops
910
from executorch.exir.pass_base import ExportPass, PassResult
1011
from executorch.exir.passes import dead_code_elimination_pass
1112

12-
from .utils import dq_ops
13-
1413

1514
class ExpandBroadcastTensorShape(ExportPass):
1615
"""

backends/qualcomm/_passes/fold_qdq.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
78
from executorch.backends.qualcomm.builders.utils import is_parameter
89
from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
910
from executorch.exir.dialects._ops import ops as exir_ops
1011
from executorch.exir.pass_base import ExportPass, PassResult
1112
from executorch.exir.passes import dead_code_elimination_pass
1213

13-
from .utils import dq_ops, q_ops
14-
1514

1615
class FoldQDQ(ExportPass):
1716
"""

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import torch
99

10+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
11+
1012
from executorch.backends.qualcomm.builders.utils import is_parameter
1113
from executorch.backends.qualcomm.utils.constants import (
1214
QCOM_ENCODING,
@@ -16,8 +18,6 @@
1618
from executorch.exir.dialects._ops import ops as exir_ops
1719
from executorch.exir.pass_base import ExportPass, PassResult
1820

19-
from .utils import q_ops
20-
2121

2222
class InsertIOQDQ(ExportPass):
2323
"""

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class TensorOpInfo:
5050
aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
5151
aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
5252
aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
53+
aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
5354
}
5455

5556

@@ -78,7 +79,7 @@ def _build_tensor_constant(
7879
# For dtype, in some cases, we cannot use node.args[0] as scalar dtype.
7980
# Ex: Where op args[0] can be bool, however, we probably want args[1] and args[2] to be dtype same as node.meta["val"] instead of bool type
8081
tensor = torch.tensor(
81-
[const_val],
82+
const_val,
8283
dtype=(
8384
node.args[0].meta["val"].dtype
8485
if not is_float_tensor(node)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Dict
1010

1111
from executorch.backends.qualcomm._passes import (
12+
AnnotateAdaptiveAvgPool1D,
1213
AnnotateQuantAttrs,
1314
AnnotateStack,
1415
AnnotateUnbind,
@@ -21,6 +22,7 @@
2122
DecomposeEinsum,
2223
DecomposeExpM1,
2324
DecomposeLinalgVectorNorm,
25+
DecomposeRoll,
2426
DecomposeSilu,
2527
ExpandBroadcastTensorShape,
2628
FixedLinearKeepDim,
@@ -74,6 +76,7 @@ def get_capture_program_passes():
7476
# The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default.
7577
# If a pass is activated, it will be executed by default.
7678
default_passes_and_setting = [
79+
(AnnotateAdaptiveAvgPool1D, True),
7780
(AnnotateQuantAttrs, True),
7881
(AnnotateStack, True),
7982
(AnnotateUnbind, True),
@@ -129,11 +132,11 @@ def get_to_edge_transform_passes(
129132
dep_table: Dict = None,
130133
):
131134
# TODO: remove this workaround when target could be correctly detected
132-
from executorch.backends.qualcomm._passes import utils
135+
from executorch.backends.qualcomm.builders import node_visitor
133136
from executorch.exir.dialects._ops import ops as exir_ops
134137

135-
utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
136-
utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
138+
node_visitor.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
139+
node_visitor.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
137140

138141
passes_job = (
139142
passes_job if passes_job is not None else get_capture_program_passes()
@@ -189,6 +192,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
189192
self.add_pass(ReplaceArangeArgs())
190193
self.add_pass(DecomposeCDist())
191194
self.add_pass(DecomposeScaledDotProductAttention())
195+
self.add_pass(DecomposeRoll())
192196
self.add_pass(DecomposeSilu())
193197
self.add_pass(DecomposeEinsum())
194198
self.add_pass(DecomposeExpM1())
@@ -200,6 +204,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
200204
def transform_for_export_pipeline(self, exported_program: ExportedProgram):
201205
self.add_pass(DecomposeCDist())
202206
self.add_pass(DecomposeScaledDotProductAttention())
207+
self.add_pass(DecomposeRoll())
203208
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
204209
self.add_pass(DecomposeExpM1())
205210
# this pass will rewrite state_dict, it needs to be accomplished before

backends/qualcomm/_passes/recompose_rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
79
from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
810
from executorch.exir.dialects._ops import ops as exir_ops
911
from executorch.exir.pass_base import ExportPass, PassResult
1012
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1113

12-
from .utils import dq_ops
13-
1414

1515
class RecomposeRmsNorm(ExportPass):
1616
"""

0 commit comments

Comments
 (0)