Skip to content

Commit e6802eb

Browse files
committed
Qualcomm AI Engine Direct - GA DIT
Summary: - Add DIT example script - Use HistogramObserver as act_observer to resolve accuracy issue - Add the test for DIT - Support UpsampleBicubic - Remove unused pass convert_upsample_bicubic2d.py
1 parent b805f17 commit e6802eb

17 files changed

+357
-56
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .convert_bmm_to_matmul import ConvertBmmToMatmul
1111
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
1212
from .convert_square_to_pow import ConvertSquareToPow
13-
from .convert_upsample_bicubic2d import ConvertUpsampleBicubicWithBilinear
1413
from .decompose_any import DecomposeAny
1514
from .decompose_cdist import DecomposeCDist
1615
from .decompose_einsum import DecomposeEinsum
@@ -44,7 +43,6 @@
4443
ConvertBmmToMatmul,
4544
ConvertConv1dToConv2d,
4645
ConvertSquareToPow,
47-
ConvertUpsampleBicubicWithBilinear,
4846
DecomposeAny,
4947
DecomposeCDist,
5048
DecomposeEinsum,

backends/qualcomm/_passes/convert_upsample_bicubic2d.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class LayoutTransform(ExportPass):
3838
exir_ops.edge.aten.native_group_norm.default,
3939
exir_ops.edge.aten.pixel_shuffle.default,
4040
exir_ops.edge.aten.pixel_unshuffle.default,
41+
exir_ops.edge.aten.upsample_bicubic2d.default,
42+
exir_ops.edge.aten.upsample_bicubic2d.vec,
4143
exir_ops.edge.aten.upsample_bilinear2d.default,
4244
exir_ops.edge.aten.upsample_bilinear2d.vec,
4345
exir_ops.edge.aten.upsample_nearest2d.default,

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
ConvertBmmToMatmul,
1616
ConvertConv1dToConv2d,
1717
ConvertSquareToPow,
18-
ConvertUpsampleBicubicWithBilinear,
1918
DecomposeAny,
2019
DecomposeCDist,
2120
DecomposeEinsum,
@@ -78,7 +77,6 @@ def get_capture_program_passes():
7877
(AnnotateUnbind, True),
7978
(ConvertBmmToMatmul, True),
8079
(ConvertConv1dToConv2d, True),
81-
(ConvertUpsampleBicubicWithBilinear, False),
8280
(DecomposeAny, True),
8381
(ExpandBroadcastTensorShape, False),
8482
(FixedLinearKeepDim, True),

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def get_passes_dependency_for_capture_program():
7878
AnnotateUnbind,
7979
ConvertBmmToMatmul,
8080
ConvertConv1dToConv2d,
81-
ConvertUpsampleBicubicWithBilinear,
8281
DecomposeAny,
8382
DecomposeLinalgVectorNorm,
8483
ExpandBroadcastTensorShape,
@@ -97,19 +96,18 @@ def get_passes_dependency_for_capture_program():
9796
AnnotateQuantAttrs: [
9897
RecomposePixelUnshuffle,
9998
ConvertBmmToMatmul,
100-
ConvertUpsampleBicubicWithBilinear,
10199
RemoveRedundancy,
102100
],
103101
AnnotateStack: [RemoveRedundancy],
104102
AnnotateUnbind: [RemoveRedundancy],
105103
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
106-
ConvertUpsampleBicubicWithBilinear: [RemoveRedundancy],
104+
ConvertConv1dToConv2d: [FoldQDQ],
107105
DecomposeAny: [RemoveRedundancy],
108106
DecomposeLinalgVectorNorm: [RemoveRedundancy],
109107
ExpandBroadcastTensorShape: [FoldQDQ],
110108
FixedLinearKeepDim: [FoldQDQ],
111109
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
112-
I64toI32: [ConvertUpsampleBicubicWithBilinear, RemoveRedundancy],
110+
I64toI32: [RemoveRedundancy],
113111
LayoutTransform: [
114112
AnnotateQuantAttrs,
115113
ConvertConv1dToConv2d,

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
op_relu,
6666
op_repeat,
6767
op_reshape,
68+
op_resize,
6869
op_rms_norm,
6970
op_rsqrt,
7071
op_scalar_tensor,
@@ -153,6 +154,7 @@
153154
op_relu,
154155
op_repeat,
155156
op_reshape,
157+
op_resize,
156158
op_rms_norm,
157159
op_rsqrt,
158160
op_scalar_tensor,
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import cast, Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
import numpy as np
10+
import torch
11+
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
13+
14+
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .qnn_constants import OpResize, QNN_OP_PACKAGE_NAME_QTI_AISW
16+
17+
18+
@register_node_visitor
19+
class Resize(NodeVisitor):
20+
# Because QNN support ResizeBilinear and ResizeNearestNeighbor, only bicubic need to be handled in resize op
21+
target = ["aten.upsample_bicubic2d.vec"]
22+
23+
def __init__(self, *args) -> None:
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
30+
) -> PyQnnWrapper.PyQnnOpWrapper:
31+
input_node = self.get_node(node.args[0])
32+
input_tensor = self.get_tensor(input_node, node)
33+
input_tensor_wrapper = self.define_tensor(
34+
input_node,
35+
node,
36+
input_tensor,
37+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
38+
nodes_to_wrappers,
39+
)
40+
align_corners = cast(bool, node.args[2])
41+
transformation_mode = np.uint32(2) if align_corners else np.uint32(1)
42+
# This builder supports only bicubic resize.
43+
interpolation_mode = np.uint32(2)
44+
cubic_coeff = np.float32(-0.75)
45+
46+
output_tensor = self.get_tensor(node, node)
47+
output_tensor_wrapper = self.define_tensor(
48+
node,
49+
node,
50+
output_tensor,
51+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
52+
nodes_to_wrappers,
53+
)
54+
resize_op = PyQnnWrapper.PyQnnOpWrapper(
55+
node.name,
56+
QNN_OP_PACKAGE_NAME_QTI_AISW,
57+
OpResize.op_name,
58+
)
59+
resize_op.AddInputTensors([input_tensor_wrapper])
60+
resize_op.AddOutputTensors([output_tensor_wrapper])
61+
62+
resize_op.AddScalarParam(
63+
OpResize.param_exclude_outside,
64+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
65+
{QCOM_DATA: False},
66+
)
67+
resize_op.AddScalarParam(
68+
OpResize.param_transformation_mode,
69+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
70+
{QCOM_DATA: transformation_mode},
71+
)
72+
73+
resize_op.AddScalarParam(
74+
OpResize.param_interpolation_mode,
75+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
76+
{QCOM_DATA: interpolation_mode},
77+
)
78+
resize_op.AddScalarParam(
79+
OpResize.param_cubic_coeff,
80+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
81+
{QCOM_DATA: cubic_coeff},
82+
)
83+
84+
return resize_op

backends/qualcomm/builders/op_upsample_bilinear2d.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,23 @@ def define_node(
4545
nodes_to_wrappers,
4646
)
4747

48-
reisze_bilinear_op = PyQnnWrapper.PyQnnOpWrapper(
48+
resize_bilinear_op = PyQnnWrapper.PyQnnOpWrapper(
4949
node.name,
5050
QNN_OP_PACKAGE_NAME_QTI_AISW,
5151
OpResizeBilinear.op_name,
5252
)
53-
reisze_bilinear_op.AddInputTensors([input_tensor_wrapper])
54-
reisze_bilinear_op.AddOutputTensors([output_tensor_wrapper])
53+
resize_bilinear_op.AddInputTensors([input_tensor_wrapper])
54+
resize_bilinear_op.AddOutputTensors([output_tensor_wrapper])
5555

56-
reisze_bilinear_op.AddScalarParam(
56+
resize_bilinear_op.AddScalarParam(
5757
OpResizeBilinear.param_align_corners,
5858
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
5959
{QCOM_DATA: node.args[2]},
6060
)
61-
reisze_bilinear_op.AddScalarParam(
61+
resize_bilinear_op.AddScalarParam(
6262
OpResizeBilinear.param_half_pixel_centers,
6363
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
6464
{QCOM_DATA: not node.args[2]},
6565
)
6666

67-
return reisze_bilinear_op
67+
return resize_bilinear_op

backends/qualcomm/builders/op_upsample_nearest2d.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,23 @@ def define_node(
4545
nodes_to_wrappers,
4646
)
4747

48-
reisze_nearest_op = PyQnnWrapper.PyQnnOpWrapper(
48+
resize_nearest_op = PyQnnWrapper.PyQnnOpWrapper(
4949
node.name,
5050
QNN_OP_PACKAGE_NAME_QTI_AISW,
5151
OpResizeNearestNeighbor.op_name,
5252
)
53-
reisze_nearest_op.AddInputTensors([input_tensor_wrapper])
54-
reisze_nearest_op.AddOutputTensors([output_tensor_wrapper])
53+
resize_nearest_op.AddInputTensors([input_tensor_wrapper])
54+
resize_nearest_op.AddOutputTensors([output_tensor_wrapper])
5555
# align_corners is guaranteed to be false
56-
reisze_nearest_op.AddScalarParam(
56+
resize_nearest_op.AddScalarParam(
5757
OpResizeNearestNeighbor.param_align_corners,
5858
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
5959
{QCOM_DATA: False},
6060
)
61-
reisze_nearest_op.AddScalarParam(
61+
resize_nearest_op.AddScalarParam(
6262
OpResizeNearestNeighbor.param_half_pixel_centers,
6363
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
6464
{QCOM_DATA: True},
6565
)
6666

67-
return reisze_nearest_op
67+
return resize_nearest_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,16 @@ class OpReshape:
402402
op_name: str = "Reshape"
403403

404404

405+
@dataclass(init=False, frozen=True)
406+
class OpResize:
407+
op_name: str = "Resize"
408+
param_exclude_outside: str = "exclude_outside"
409+
param_transformation_mode: str = "transformation_mode"
410+
param_interpolation_mode: str = "interpolation_mode"
411+
param_nearest_mode: str = "nearest_mode"
412+
param_cubic_coeff: str = "cubic_coeff"
413+
414+
405415
@dataclass(init=False, frozen=True)
406416
class OpResizeBilinear:
407417
op_name: str = "ResizeBilinear"

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
exir_ops.edge.aten.clone.default,
1414
exir_ops.edge.aten.slice_scatter.default,
1515
exir_ops.edge.aten.copy.default,
16-
exir_ops.edge.aten.upsample_bicubic2d.vec,
1716
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
1817
]
1918

backends/qualcomm/quantizer/annotators.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,13 @@ def annotate_upsample_bilinear2d(
514514
annotate_single_in_single_out(node, quantization_config)
515515

516516

517+
@register_annotator([torch.ops.aten.upsample_bicubic2d.vec])
518+
def annotate_upsample_upsample_bicubic2d(
519+
node: Node, quantization_config: QuantizationConfig
520+
) -> None:
521+
annotate_single_in_single_out(node, quantization_config)
522+
523+
517524
@register_annotator([torch.ops.aten.upsample_nearest2d.vec])
518525
def annotate_upsample_nearest2d(
519526
node: Node, quantization_config: QuantizationConfig

backends/qualcomm/tests/models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,23 @@ def forward(self, x):
12521252
return x6
12531253

12541254

1255+
class ResizeBicubic(torch.nn.Module):
1256+
def __init__(self, size, scale_factor, align_corners):
1257+
super().__init__()
1258+
self.align_corners = align_corners
1259+
self.scale_factor = scale_factor
1260+
self.size = size
1261+
1262+
def forward(self, x):
1263+
return torch.nn.functional.interpolate(
1264+
x,
1265+
size=self.size,
1266+
scale_factor=self.scale_factor,
1267+
mode="bicubic",
1268+
align_corners=self.align_corners,
1269+
)
1270+
1271+
12551272
class ResizeBilinear2D(torch.nn.Module):
12561273
def __init__(self):
12571274
super().__init__()

0 commit comments

Comments
 (0)