Skip to content

Commit 1f614ea

Browse files
Qualcomm AI Engine Direct - oss model enablement (EfficientSAM)
- e2e script for https://github.com/yformer/EfficientSAM - Fastvit breakage fix - Passes order correction - Add support for cum_sum
1 parent b0c2c7c commit 1f614ea

19 files changed

+752
-13
lines changed

backends/qualcomm/_passes/layout_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class LayoutTransform(ExportPass):
5353
exir_ops.edge.aten.ceil.default,
5454
exir_ops.edge.aten.clamp.default,
5555
exir_ops.edge.aten.constant_pad_nd.default,
56+
exir_ops.edge.aten.cumsum.default,
5657
exir_ops.edge.aten.div.Tensor,
5758
exir_ops.edge.aten.eq.Tensor,
5859
exir_ops.edge.aten.full.default,

backends/qualcomm/_passes/lift_constant_scalar_operands.py

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class TensorOpInfo:
4646
aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False),
4747
# The scalar number arg[1] is missing when using default. Result in a corner case to deal
4848
aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True),
49+
aten.where.ScalarOther: TensorOpInfo(aten.where.self, False),
4950
}
5051

5152

backends/qualcomm/_passes/recompose_pixel_unshuffle.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,11 @@ def call(self, graph_module: torch.fx.GraphModule):
4545
continue
4646

4747
view_node = premute_node.args[0]
48-
if any(
49-
[
50-
view_node.op != "call_function",
51-
view_node.target != self.view_target,
52-
len(view_node.args[1]) != 6,
53-
len(premute_node.args[1]) != 6,
54-
]
48+
if (
49+
view_node.op != "call_function"
50+
or view_node.target != self.view_target
51+
or len(view_node.args[1]) != 6
52+
or len(premute_node.args[1]) != 6
5553
):
5654
continue
5755

backends/qualcomm/_passes/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def get_passes_dependency_for_capture_program():
9393
ConvertToLinear: [RecomposePixelUnshuffle],
9494
DecomposeAny: [RemoveRedundancy],
9595
DecomposeLinalgVectorNorm: [RemoveRedundancy],
96-
ExpandBroadcastTensorShape: [RemoveRedundancy],
96+
ExpandBroadcastTensorShape: [ConstantI64toI32, TensorI64toI32],
9797
FoldQDQ: [AnnotateQuantAttrs, AnnotateDecomposed],
9898
LayoutTransform: [
9999
AnnotateQuantAttrs,

backends/qualcomm/builders/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
op_clamp,
2020
op_conv2d,
2121
op_cos,
22+
op_cum_sum,
2223
op_depth_to_space,
2324
op_dequantize,
2425
op_div,
@@ -99,6 +100,7 @@
99100
op_clamp,
100101
op_conv2d,
101102
op_cos,
103+
op_cum_sum,
102104
op_depth_to_space,
103105
op_dequantize,
104106
op_div,

backends/qualcomm/builders/op_cos.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6-
76
from typing import Dict
87

98
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import cast, Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import numpy as np
11+
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
13+
14+
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .qnn_constants import OpCumulativeSum, QNN_OP_PACKAGE_NAME_QTI_AISW
16+
17+
18+
@register_node_visitor
class CumulativeSum(NodeVisitor):
    """Node visitor that builds a QNN ``CumulativeSum`` op for ``aten.cumsum.default``."""

    target = ["aten.cumsum.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def get_param(self, node, input_tensor):
        """Return the value for the op's ``axis`` scalar parameter.

        Args:
            node: The cumsum node; ``node.args[1]`` is read as the dimension.
            input_tensor: The tensor being accumulated. Only its rank
                (``len(input_tensor.shape)``) is used, to wrap a negative dim.

        Returns:
            The axis as ``np.uint32``, matching the ``QNN_DATATYPE_UINT_32``
            data type declared for the parameter in ``define_node``.
        """
        dim = node.args[1]

        if dim < 0:
            dim = dim % len(input_tensor.shape)
        if QCOM_AXIS_ORDER in node.meta:
            # Remap the dim to its position inside the recorded axis order.
            # NOTE(review): presumably this is the layout permutation attached
            # by the layout-transform pass -- confirm.
            dim = node.meta[QCOM_AXIS_ORDER].index(dim)

        # typing.cast() does nothing at runtime (it would hand back a plain
        # int), so perform a real conversion to the declared unsigned type.
        return np.uint32(dim)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the op wrapper for ``node``.

        Args:
            node: The cumsum node; ``node.args[0]`` is the input node.
            nodes_to_wrappers: Mapping handed to ``define_tensor`` for both the
                input and the output tensor.

        Returns:
            A ``PyQnnOpWrapper`` with one input tensor, one output tensor and
            the ``axis`` / ``exclusive`` / ``reverse`` scalar parameters set.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        dim = self.get_param(node, input_tensor)

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        cumsum_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpCumulativeSum.op_name,
        )
        cumsum_op.AddInputTensors([input_tensor_wrapper])
        cumsum_op.AddOutputTensors([output_tensor_wrapper])
        cumsum_op.AddScalarParam(
            OpCumulativeSum.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: dim},
        )
        # Both flags are always disabled.
        # NOTE(review): presumably this yields the inclusive, forward running
        # sum of torch.cumsum -- confirm against the QNN op definition.
        cumsum_op.AddScalarParam(
            OpCumulativeSum.param_exclusive,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: False},
        )
        cumsum_op.AddScalarParam(
            OpCumulativeSum.param_reverse,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: False},
        )

        return cumsum_op

backends/qualcomm/builders/op_sin.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6-
76
from typing import Dict
87

98
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

backends/qualcomm/builders/qnn_constants.py

+8
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ class OpConvert:
5050
op_name: str = "Convert"
5151

5252

53+
@dataclass(init=False, frozen=True)
class OpCumulativeSum:
    # Operator name and scalar-parameter names used when building the
    # CumulativeSum op. The attributes are annotated so that @dataclass
    # registers them as fields, as the neighbouring Op* classes do.
    op_name: str = "CumulativeSum"
    param_axis: str = "axis"
    param_exclusive: str = "exclusive"
    param_reverse: str = "reverse"
59+
60+
5361
@dataclass(init=False, frozen=True)
5462
class OpDepthToSpace:
5563
op_name: str = "DepthToSpace"

backends/qualcomm/quantizer/annotators.py

+5
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,11 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
925925
)
926926

927927

928+
@register_annotator([torch.ops.aten.cumsum.default])
def annotate_cumsum(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate a cumsum node by delegating to ``annotate_single_in_single_out``."""
    annotate_single_in_single_out(node, quantization_config)
931+
932+
928933
@register_annotator([torch.ops.aten.linear.default])
929934
def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
930935
act_node = node.args[0]

backends/qualcomm/tests/models.py

+16
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,14 @@ def forward(self, x):
529529
return torch.cos(x)
530530

531531

532+
class CumSum(torch.nn.Module):
    """Parameter-free test module: cumulative sum of the input along dim 0."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Function form of the same cumsum call on the tensor.
        return torch.cumsum(x, dim=0)
538+
539+
532540
class Div(torch.nn.Module):
533541
def __init__(self):
534542
super().__init__()
@@ -1469,3 +1477,11 @@ def __init__(self, pos, neg):
14691477

14701478
def forward(self, x):
14711479
return torch.where(x >= torch.zeros(x.shape), self.pos, self.neg)
1480+
1481+
1482+
class WhereConstantOther(torch.nn.Module):
    """Test module: ``torch.where`` with a tensor of ones and a Python scalar ``0``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        non_negative = x >= 0
        ones = torch.ones(x.shape)
        # The last operand is deliberately a Python scalar, not a tensor.
        return torch.where(non_negative, ones, 0)

backends/qualcomm/tests/test_qnn_delegate.py

+55
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ def test_qnn_backend_cos(self):
208208
sample_input = (torch.randn(2, 5, 1, 3),)
209209
self.lower_module_and_test_output(module, sample_input)
210210

211+
def test_qnn_backend_cumsum(self):
    # Run the CumSum module on a random 1-D input of length 4 through
    # lower_module_and_test_output.
    sample_input = (torch.randn(4),)
    self.lower_module_and_test_output(CumSum(), sample_input)  # noqa: F405
215+
211216
def test_qnn_backend_einsum_outer_product(self):
212217
module = EinsumOuterProduct() # noqa: F405
213218
x = torch.randn(5)
@@ -790,10 +795,12 @@ def test_qnn_backend_where(self):
790795
modules = [
791796
Where(), # noqa: F405
792797
WhereConstant(torch.randn(3, 2), torch.randn(3, 2)), # noqa: F405
798+
WhereConstantOther(), # noqa: F405
793799
]
794800
sample_inputs = [
795801
(torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
796802
(torch.randn(3, 2),),
803+
(torch.randn(3, 2),),
797804
]
798805
for i, module in enumerate(modules):
799806
self.lower_module_and_test_output(module, sample_inputs[i])
@@ -1165,6 +1172,12 @@ def test_qnn_backend_cos(self):
11651172
module = self.get_qdq_module(module, sample_input)
11661173
self.lower_module_and_test_output(module, sample_input)
11671174

1175+
def test_qnn_backend_cumsum(self):
    # Same as the floating-point variant, except the module first goes
    # through get_qdq_module with the same sample input.
    sample_input = (torch.randn(4),)
    qdq_module = self.get_qdq_module(CumSum(), sample_input)  # noqa: F405
    self.lower_module_and_test_output(qdq_module, sample_input)
1180+
11681181
def test_qnn_backend_einsum_outer_product(self):
11691182
module = EinsumOuterProduct() # noqa: F405
11701183
x = torch.randn(5)
@@ -1826,10 +1839,12 @@ def test_qnn_backend_where(self):
18261839
modules = [
18271840
Where(), # noqa: F405
18281841
WhereConstant(torch.randn(3, 2), torch.randn(3, 2)), # noqa: F405
1842+
WhereConstantOther(), # noqa: F405
18291843
]
18301844
sample_inputs = [
18311845
(torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
18321846
(torch.randn(3, 2),),
1847+
(torch.randn(3, 2),),
18331848
]
18341849
for i, module in enumerate(modules):
18351850
module = self.get_qdq_module(module, sample_inputs[i])
@@ -3421,6 +3436,46 @@ def test_dino_v2(self):
34213436
self.assertGreaterEqual(msg["top_1"], 70)
34223437
self.assertGreaterEqual(msg["top_5"], 85)
34233438

3439+
def test_efficientSAM(self):
    """Run the EfficientSAM example script and check the MIoU it reports.

    The script is started as a subprocess; its result is read as a JSON
    message from a connection accepted on ``(self.ip, self.port)``.
    """
    needed_envs = [self.image_dataset, self.pretrained_weight, self.oss_repo]
    if not self.required_envs(needed_envs):
        self.skipTest("missing required envs")

    cmds = [
        "python",
        f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientSAM.py",
    ]
    # Flag/value pairs, appended in the same order as the command line.
    options = (
        ("--dataset", self.image_dataset),
        ("--artifact", self.artifact_dir),
        ("--build_folder", self.build_folder),
        ("--device", self.device),
        ("--model", self.model),
        ("--oss_repo", self.oss_repo),
        ("--pretrained_weight", self.pretrained_weight),
        ("--ip", self.ip),
        ("--port", str(self.port)),
    )
    for flag, value in options:
        cmds += [flag, value]
    if self.host:
        cmds += ["--host", self.host]

    p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
    with Listener((self.ip, self.port)) as listener:
        conn = listener.accept()
        p.communicate()
        msg = json.loads(conn.recv())
        # self.fail() raises, so the assertion below is only reached when
        # the message carries no "Error" key.
        if "Error" in msg:
            self.fail(msg["Error"])
        self.assertGreaterEqual(msg["MIoU"], 0.55)
3478+
34243479
def test_esrgan(self):
34253480
if not self.required_envs():
34263481
self.skipTest("missing required envs")

backends/qualcomm/tests/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -459,12 +459,13 @@ def lower_module_and_test_output(
459459
skip_node_id_set: set = None,
460460
skip_node_op_set: set = None,
461461
dynamic_shapes: Dict = None,
462+
passes_job: collections.OrderedDict = None,
462463
):
463464
qnn_partitioner = QnnPartitioner(
464465
self.compiler_specs, skip_node_id_set, skip_node_op_set
465466
)
466467
delegated_program = capture_program(
467-
module, sample_inputs, dynamic_shapes=dynamic_shapes
468+
module, sample_inputs, dynamic_shapes=dynamic_shapes, passes_job=passes_job
468469
)
469470

470471
# this is needed for the ETRecord as lowering modifies the graph in-place

0 commit comments

Comments
 (0)