Skip to content

Commit b3e6876

Browse files
shewu-quic and kirklandsign
authored and committed
Qualcomm AI Engine Direct - Add QNN support for to_edge_transform_and_lower (#9643)
Summary:
- Support `to_edge_transform_and_lower`.
- Replace capture_program with the new API `to_edge_transform_and_lower_to_qnn`.
- Replace capture_program with to_edge_transform_and_lower_to_qnn for unit_test.
- Replace capture_program with to_edge_transform_and_lower_to_qnn for examples.
- Replace capture_program with to_edge_transform_and_lower_to_qnn for llama.
- Add QnnPassManager to manage all passes in the different stages.
- Deprecate _transform in export_llama_lib in favor of qnn_pass_manager.
- Add transform_for_export_pipeline for LiftConstantScalarOperands to avoid creating temporary tensors in the operation builder. However, this pass creates a get_attr node, which must be converted into a lifted tensor constant by the lift_constant_tensor_pass. If the pass were placed in to_edge_transform_passes, it would run after the lift_constant_tensor_pass, and the operation builder would then fail to retrieve the parameter via get_parameter for the get_attr node.
- Refactor the passes.
- Fix the output dtype mismatch at runtime after build_quant_io.
- Combine constant_i64_to_i32 and tensor_i64_to_i32 into i64_to_i32.
- Replace the convert_to_linear pass with the fixed_linear_keep_dim pass.
- Since QNN has no keep-dims option for the linear op, squeeze and unsqueeze nodes need to be added around the linear node.
- Add the TagQuantIO pass to tag I/O nodes so that q/dq nodes are not inserted in qnn_preprocess.
- Add prelu, leaky_relu, linear, and rms_norm to the decompose_table.
- Remove recompose_prelu.py.
- Remove unused variables in insert_requantize.py and replace_index_put_input.py.
- Support aten.split_with_sizes_copy.default.
- Support leaky_relu with inplace=True.
1 parent 6be5933 commit b3e6876

36 files changed

+1063
-1231
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,51 +4,51 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from .annotate_decomposed import AnnotateDecomposed
87
from .annotate_quant_attrs import AnnotateQuantAttrs
9-
from .constant_i64_to_i32 import ConstantI64toI32
8+
from .annotate_stack import AnnotateStack
9+
from .annotate_unbind import AnnotateUnbind
1010
from .convert_bmm_to_matmul import ConvertBmmToMatmul
1111
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
12-
from .convert_to_linear import ConvertToLinear
1312
from .decompose_any import DecomposeAny
1413
from .decompose_einsum import DecomposeEinsum
1514
from .decompose_expm1 import DecomposeExpM1
1615
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
1716
from .decompose_silu import DecomposeSilu
1817
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
18+
from .fixed_linear_keep_dim import FixedLinearKeepDim
1919
from .fold_qdq import FoldQDQ
2020
from .fuse_consecutive_transpose import FuseConsecutiveTranspose
21+
from .i64_to_i32 import I64toI32
2122
from .insert_io_qdq import InsertIOQDQ
2223
from .insert_requantize import InsertRequantize
2324
from .layout_transform import LayoutTransform
2425
from .lift_constant_scalar_operands import LiftConstantScalarOperands
2526
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
26-
from .recompose_prelu import RecomposePReLU
2727
from .recompose_rms_norm import RecomposeRmsNorm
2828
from .reduce_dynamic_range import ReduceDynamicRange
2929
from .remove_redundancy import RemoveRedundancy
3030
from .replace_arange_args import ReplaceArangeArgs
3131
from .replace_index_put_input import ReplaceIndexPutInput
3232
from .replace_inf_values import ReplaceInfValues
33-
from .tensor_i64_to_i32 import TensorI64toI32
33+
from .tag_quant_io import TagQuantIO
3434

3535

3636
__all__ = [
37-
AnnotateDecomposed,
3837
AnnotateQuantAttrs,
39-
ConstantI64toI32,
38+
AnnotateStack,
39+
AnnotateUnbind,
4040
ConvertBmmToMatmul,
4141
ConvertConv1dToConv2d,
42-
RecomposePReLU,
43-
ConvertToLinear,
4442
DecomposeAny,
4543
DecomposeEinsum,
4644
DecomposeExpM1,
4745
DecomposeLinalgVectorNorm,
4846
DecomposeSilu,
4947
ExpandBroadcastTensorShape,
48+
FixedLinearKeepDim,
5049
FoldQDQ,
5150
FuseConsecutiveTranspose,
51+
I64toI32,
5252
InsertIOQDQ,
5353
InsertRequantize,
5454
LayoutTransform,
@@ -60,5 +60,5 @@
6060
ReplaceArangeArgs,
6161
ReplaceIndexPutInput,
6262
ReplaceInfValues,
63-
TensorI64toI32,
63+
TagQuantIO,
6464
]

backends/qualcomm/_passes/annotate_decomposed.py renamed to backends/qualcomm/_passes/annotate_stack.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,21 @@
88
from executorch.exir.pass_base import ExportPass, PassResult
99
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1010

11-
from .utils import dq_ops, get_quant_attrs, q_ops
11+
from .utils import get_quant_attrs, q_ops
1212

1313

14-
class AnnotateDecomposed(ExportPass):
14+
class AnnotateStack(ExportPass):
1515
"""
1616
Add "quant_attrs" to graph nodes' meta from the QDQ information
1717
generated after quantization process.
1818
"""
1919

20-
decomp_ops = [torch.ops.aten.stack.default, torch.ops.aten.unbind.int]
20+
decomp_ops = [torch.ops.aten.unbind.int]
2121

2222
def __init__(self, edge_program: torch.export.ExportedProgram):
23-
super(AnnotateDecomposed, self).__init__()
23+
super(AnnotateStack, self).__init__()
2424
self.edge_program = edge_program
2525

26-
def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
27-
partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"])
28-
for _, src_partitions in partitions.items():
29-
for src_partition in src_partitions:
30-
if src_partition.input_nodes[0].target in dq_ops:
31-
q_node = src_partition.input_nodes[0].args[0]
32-
quant_attrs = get_quant_attrs(self.edge_program, q_node)
33-
for n in src_partition.nodes:
34-
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
35-
3626
def _annotate_stack(self, graph_module: torch.fx.GraphModule):
3727
partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"])
3828
for _, src_partitions in partitions.items():
@@ -46,7 +36,6 @@ def _annotate_stack(self, graph_module: torch.fx.GraphModule):
4636
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
4737

4838
def call(self, graph_module: torch.fx.GraphModule):
49-
self._annotate_unbind(graph_module)
5039
self._annotate_stack(graph_module)
5140
graph_module.recompile()
5241
return PassResult(graph_module, True)
backends/qualcomm/_passes/annotate_unbind.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
10+
11+
from .utils import dq_ops, get_quant_attrs
12+
13+
14+
class AnnotateUnbind(ExportPass):
15+
"""
16+
Add "quant_attrs" to graph nodes' meta from the QDQ information
17+
generated after quantization process.
18+
"""
19+
20+
decomp_ops = [torch.ops.aten.unbind.int]
21+
22+
def __init__(self, edge_program: torch.export.ExportedProgram):
23+
super(AnnotateUnbind, self).__init__()
24+
self.edge_program = edge_program
25+
26+
def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
27+
partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"])
28+
for _, src_partitions in partitions.items():
29+
for src_partition in src_partitions:
30+
if src_partition.input_nodes[0].target in dq_ops:
31+
q_node = src_partition.input_nodes[0].args[0]
32+
quant_attrs = get_quant_attrs(self.edge_program, q_node)
33+
for n in src_partition.nodes:
34+
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
35+
36+
def call(self, graph_module: torch.fx.GraphModule):
37+
self._annotate_unbind(graph_module)
38+
graph_module.recompile()
39+
return PassResult(graph_module, True)

backends/qualcomm/_passes/build_quant_io.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,17 @@ def _make_spec(self, x):
2727
return None
2828

2929
def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
30-
# forcely update delegate node's meta['spec'] to get correct output
30+
# Forcedly update delegate node's meta['spec'] to get correct output
3131
# tensor size in runtime
3232
call_delegate = [
3333
node
3434
for node in graph_module.graph.nodes
3535
if node.op == "call_function" and node.name == "executorch_call_delegate"
3636
]
3737
assert len(call_delegate) == 1
38-
spec = []
3938
for n in graph_module.graph.nodes:
4039
if QCOM_QUANTIZED_IO in n.meta:
4140
n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO])
42-
if n.op == "call_function" and "getitem" in n.name:
43-
fake_tensor = n.meta["val"]
44-
if QCOM_QUANTIZED_IO in n.meta:
45-
fake_tensor = fake_tensor.to(dtype=n.meta[QCOM_QUANTIZED_IO])
46-
spec.append(self._make_spec(fake_tensor))
47-
48-
call_delegate[0].meta["spec"] = tuple(spec)
4941

5042
def call(self, graph_module: torch.fx.GraphModule):
5143
self._build(graph_module)

backends/qualcomm/_passes/constant_i64_to_i32.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

0 commit comments

Comments
 (0)