Skip to content

Commit 90686fa

Browse files
committed
Update on "[ET-VK] Remove the use of shared memory in conv2d pw to improve perf."
This diff removes the use of shared memory in the conv2d pw (pointwise) operation to improve performance. Differential Revision: [D75316188](https://our.internmc.facebook.com/intern/diff/D75316188/) [ghstack-poisoned]
2 parents 472026c + 28886cd commit 90686fa

27 files changed

+975
-198
lines changed

.github/workflows/apple-perf.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,7 @@ jobs:
386386
echo "::endgroup::"
387387
388388
echo "::group::Build ExecuTorch iOS frameworks"
389-
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
390-
scripts/build_apple_frameworks.sh --Release --Debug --coreml --custom --mps --optimized --portable --quantized --xnnpack
389+
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output scripts/build_apple_frameworks.sh
391390
echo "::endgroup::"
392391
393392
# NB: Although exported models can be copied to this directory and bundled together with the

.github/workflows/apple.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,7 @@ jobs:
173173
backends/apple/mps/install_requirements.sh
174174
175175
# Build iOS Frameworks
176-
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
177-
scripts/build_apple_frameworks.sh --Release --Debug --coreml --custom --mps --optimized --portable --quantized --xnnpack
176+
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output scripts/build_apple_frameworks.sh
178177
179178
# Bundle iOS Frameworks
180179
for FRAMEWORK in "${FRAMEWORKS[@]}"; do (
@@ -314,8 +313,7 @@ jobs:
314313
echo "::endgroup::"
315314
316315
echo "::group::Build ExecuTorch iOS frameworks"
317-
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
318-
scripts/build_apple_frameworks.sh --Release --Debug --coreml --custom --mps --optimized --portable --quantized --xnnpack
316+
PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output scripts/build_apple_frameworks.sh
319317
echo "::endgroup::"
320318
321319
echo "::group::Build ExecuTorch benchmark app"

backends/qualcomm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D
78
from .annotate_quant_attrs import AnnotateQuantAttrs
89
from .annotate_stack import AnnotateStack
910
from .annotate_unbind import AnnotateUnbind
@@ -16,6 +17,7 @@
1617
from .decompose_einsum import DecomposeEinsum
1718
from .decompose_expm1 import DecomposeExpM1
1819
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
20+
from .decompose_roll import DecomposeRoll
1921
from .decompose_silu import DecomposeSilu
2022
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
2123
from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -39,6 +41,7 @@
3941

4042

4143
__all__ = [
44+
AnnotateAdaptiveAvgPool1D,
4245
AnnotateQuantAttrs,
4346
AnnotateStack,
4447
AnnotateUnbind,
@@ -51,6 +54,7 @@
5154
DecomposeEinsum,
5255
DecomposeExpM1,
5356
DecomposeLinalgVectorNorm,
57+
DecomposeRoll,
5458
DecomposeSilu,
5559
ExpandBroadcastTensorShape,
5660
FixedLinearKeepDim,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
8+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
11+
12+
from .utils import get_quant_attrs
13+
14+
15+
class AnnotateAdaptiveAvgPool1D(ExportPass):
16+
"""
17+
Add "quant_attrs" to graph nodes' meta from the QDQ information
18+
generated after quantization process.
19+
adaptive_avg_pool1d got decomposed to unsqueeze -> adaptive_avg_pool2d -> squeeze
20+
"""
21+
22+
decomp_ops = [torch.ops.aten.adaptive_avg_pool2d.default]
23+
24+
def __init__(self, edge_program: torch.export.ExportedProgram):
25+
super(AnnotateAdaptiveAvgPool1D, self).__init__()
26+
self.edge_program = edge_program
27+
28+
def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
29+
partitions = get_source_partitions(
30+
graph_module.graph, [torch.ops.aten.adaptive_avg_pool1d.default]
31+
)
32+
for src_partitions in partitions.values():
33+
for src_partition in src_partitions:
34+
output = src_partition.output_nodes[0]
35+
if (list(output.users)[0].target) in q_ops:
36+
quant_attrs = get_quant_attrs(
37+
self.edge_program, list(output.users)[0]
38+
)
39+
for n in src_partition.nodes:
40+
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
41+
42+
def call(self, graph_module: torch.fx.GraphModule):
43+
self._annotate_adaptive_avg_pool1d(graph_module)
44+
graph_module.recompile()
45+
return PassResult(graph_module, True)

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict
88

99
import torch
10+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
1011
from executorch.backends.qualcomm.builders.utils import get_parameter
1112
from executorch.backends.qualcomm.utils.constants import (
1213
QCOM_DTYPE,
@@ -20,7 +21,7 @@
2021
)
2122
from executorch.exir.pass_base import ExportPass, PassResult
2223

23-
from .utils import dq_ops, get_quant_attrs, q_ops
24+
from .utils import get_quant_attrs
2425

2526

2627
class AnnotateQuantAttrs(ExportPass):

backends/qualcomm/_passes/annotate_stack.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
78
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
89
from executorch.exir.pass_base import ExportPass, PassResult
910
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1011

11-
from .utils import get_quant_attrs, q_ops
12+
from .utils import get_quant_attrs
1213

1314

1415
class AnnotateStack(ExportPass):

backends/qualcomm/_passes/annotate_unbind.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
78
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
89
from executorch.exir.pass_base import ExportPass, PassResult
910
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1011

11-
from .utils import dq_ops, get_quant_attrs
12+
from .utils import get_quant_attrs
1213

1314

1415
class AnnotateUnbind(ExportPass):
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
from .utils import copy_nn_module_stack
11+
12+
13+
class SliceCopy(torch.nn.Module):
14+
def __init__(self, val_shape, shifts, dims):
15+
super().__init__()
16+
self.val_shape = val_shape
17+
if dims[0] is None:
18+
self.shifts = [shifts[0] % torch.numel(torch.tensor(val_shape))]
19+
else:
20+
self.shifts = [shift % val_shape[dim] for shift, dim in zip(shifts, dims)]
21+
self.dims = dims
22+
23+
def forward(self, x):
24+
if self.dims[0] is None:
25+
y = x.flatten()
26+
y = torch.cat((y[-self.shifts[0] :], y[: -self.shifts[0]]))
27+
return y.view(self.val_shape)
28+
29+
for shift, dim in zip(self.shifts, self.dims):
30+
x = torch.cat(
31+
(
32+
x[(slice(None),) * dim + (slice(-shift, None),)],
33+
x[(slice(None),) * dim + (slice(0, -shift),)],
34+
),
35+
dim=dim,
36+
)
37+
return x
38+
39+
40+
class DecomposeRoll(ExportPass):
41+
"""
42+
Decompose roll into slice and cat.
43+
"""
44+
45+
def __init__(self) -> None:
46+
super().__init__()
47+
48+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
49+
graph = graph_module.graph
50+
for node in graph.nodes:
51+
if "roll" in str(node.target):
52+
input_node, shifts = node.args[0], node.args[1]
53+
dims = node.args[2] if len(node.args) == 3 else None
54+
55+
# Normalize shifts and dims to lists
56+
shifts = shifts if isinstance(shifts, (list, tuple)) else [shifts]
57+
dims = dims if isinstance(dims, (list, tuple)) else [dims]
58+
59+
model = SliceCopy(input_node.meta["val"].shape, shifts, dims)
60+
decomposed_module = torch.export.export(
61+
model, (input_node.meta["val"],), strict=True
62+
).module()
63+
64+
with graph.inserting_before(node):
65+
# remap is used to map original node values to new node values,
66+
# which ensures that reference to nodes are correctly updated in the new graph
67+
remap = {"x": input_node}
68+
69+
for decomposed_node in decomposed_module.graph.nodes:
70+
copy_nn_module_stack(node, decomposed_node)
71+
# no need to copy existent 'output'
72+
if decomposed_node.op == "output":
73+
for user in node.users.copy():
74+
# remap
75+
user.replace_input_with(
76+
node,
77+
remap[decomposed_node.args[0][0]],
78+
)
79+
# no need to copy existent placeholders
80+
elif decomposed_node.op == "placeholder":
81+
# replace node map from string to graph node
82+
remap[decomposed_node] = remap.pop(decomposed_node.name)
83+
else:
84+
remap[decomposed_node] = graph.node_copy(
85+
decomposed_node,
86+
arg_transform=lambda x, remap=remap: remap[x],
87+
)
88+
89+
graph.erase_node(node)
90+
91+
graph.eliminate_dead_code()
92+
graph_module.recompile()
93+
return PassResult(graph_module, True)

backends/qualcomm/_passes/expand_broadcast_tensor_shape.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
89
from executorch.exir.dialects._ops import ops as exir_ops
910
from executorch.exir.pass_base import ExportPass, PassResult
1011
from executorch.exir.passes import dead_code_elimination_pass
1112

12-
from .utils import dq_ops
13-
1413

1514
class ExpandBroadcastTensorShape(ExportPass):
1615
"""

backends/qualcomm/_passes/fold_qdq.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
78
from executorch.backends.qualcomm.builders.utils import is_parameter
89
from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
910
from executorch.exir.dialects._ops import ops as exir_ops
1011
from executorch.exir.pass_base import ExportPass, PassResult
1112
from executorch.exir.passes import dead_code_elimination_pass
1213

13-
from .utils import dq_ops, q_ops
14-
1514

1615
class FoldQDQ(ExportPass):
1716
"""

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import torch
99

10+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
11+
1012
from executorch.backends.qualcomm.builders.utils import is_parameter
1113
from executorch.backends.qualcomm.utils.constants import (
1214
QCOM_ENCODING,
@@ -16,8 +18,6 @@
1618
from executorch.exir.dialects._ops import ops as exir_ops
1719
from executorch.exir.pass_base import ExportPass, PassResult
1820

19-
from .utils import q_ops
20-
2121

2222
class InsertIOQDQ(ExportPass):
2323
"""

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class TensorOpInfo:
5050
aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
5151
aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
5252
aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
53+
aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
5354
}
5455

5556

@@ -78,7 +79,7 @@ def _build_tensor_constant(
7879
# For dtype, in some cases, we cannot use node.args[0] as scalar dtype.
7980
# Ex: Where op args[0] can be bool, however, we probably want args[1] and args[2] to be dtype same as node.meta["val"] instead of bool type
8081
tensor = torch.tensor(
81-
[const_val],
82+
const_val,
8283
dtype=(
8384
node.args[0].meta["val"].dtype
8485
if not is_float_tensor(node)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Dict
1010

1111
from executorch.backends.qualcomm._passes import (
12+
AnnotateAdaptiveAvgPool1D,
1213
AnnotateQuantAttrs,
1314
AnnotateStack,
1415
AnnotateUnbind,
@@ -21,6 +22,7 @@
2122
DecomposeEinsum,
2223
DecomposeExpM1,
2324
DecomposeLinalgVectorNorm,
25+
DecomposeRoll,
2426
DecomposeSilu,
2527
ExpandBroadcastTensorShape,
2628
FixedLinearKeepDim,
@@ -74,6 +76,7 @@ def get_capture_program_passes():
7476
# The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default.
7577
# If a pass is activated, it will be executed by default.
7678
default_passes_and_setting = [
79+
(AnnotateAdaptiveAvgPool1D, True),
7780
(AnnotateQuantAttrs, True),
7881
(AnnotateStack, True),
7982
(AnnotateUnbind, True),
@@ -129,11 +132,11 @@ def get_to_edge_transform_passes(
129132
dep_table: Dict = None,
130133
):
131134
# TODO: remove this workaround when target could be correctly detected
132-
from executorch.backends.qualcomm._passes import utils
135+
from executorch.backends.qualcomm.builders import node_visitor
133136
from executorch.exir.dialects._ops import ops as exir_ops
134137

135-
utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
136-
utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
138+
node_visitor.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
139+
node_visitor.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
137140

138141
passes_job = (
139142
passes_job if passes_job is not None else get_capture_program_passes()
@@ -189,6 +192,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
189192
self.add_pass(ReplaceArangeArgs())
190193
self.add_pass(DecomposeCDist())
191194
self.add_pass(DecomposeScaledDotProductAttention())
195+
self.add_pass(DecomposeRoll())
192196
self.add_pass(DecomposeSilu())
193197
self.add_pass(DecomposeEinsum())
194198
self.add_pass(DecomposeExpM1())
@@ -200,6 +204,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
200204
def transform_for_export_pipeline(self, exported_program: ExportedProgram):
201205
self.add_pass(DecomposeCDist())
202206
self.add_pass(DecomposeScaledDotProductAttention())
207+
self.add_pass(DecomposeRoll())
203208
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
204209
self.add_pass(DecomposeExpM1())
205210
# this pass will rewrite state_dict, it needs to be accomplished before

backends/qualcomm/_passes/recompose_rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
79
from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
810
from executorch.exir.dialects._ops import ops as exir_ops
911
from executorch.exir.pass_base import ExportPass, PassResult
1012
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1113

12-
from .utils import dq_ops
13-
1414

1515
class RecomposeRmsNorm(ExportPass):
1616
"""

0 commit comments

Comments
 (0)