Skip to content

Qualcomm AI Engine Direct - GA DIT #11093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
from .convert_square_to_pow import ConvertSquareToPow
from .convert_upsample_bicubic2d import ConvertUpsampleBicubicWithBilinear
from .decompose_any import DecomposeAny
from .decompose_cdist import DecomposeCDist
from .decompose_einsum import DecomposeEinsum
Expand Down Expand Up @@ -44,7 +43,6 @@
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertSquareToPow,
ConvertUpsampleBicubicWithBilinear,
DecomposeAny,
DecomposeCDist,
DecomposeEinsum,
Expand Down
27 changes: 0 additions & 27 deletions backends/qualcomm/_passes/convert_upsample_bicubic2d.py

This file was deleted.

2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.native_group_norm.default,
exir_ops.edge.aten.pixel_shuffle.default,
exir_ops.edge.aten.pixel_unshuffle.default,
exir_ops.edge.aten.upsample_bicubic2d.default,
exir_ops.edge.aten.upsample_bicubic2d.vec,
exir_ops.edge.aten.upsample_bilinear2d.default,
exir_ops.edge.aten.upsample_bilinear2d.vec,
exir_ops.edge.aten.upsample_nearest2d.default,
Expand Down
2 changes: 0 additions & 2 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertSquareToPow,
ConvertUpsampleBicubicWithBilinear,
DecomposeAny,
DecomposeCDist,
DecomposeEinsum,
Expand Down Expand Up @@ -78,7 +77,6 @@ def get_capture_program_passes():
(AnnotateUnbind, True),
(ConvertBmmToMatmul, True),
(ConvertConv1dToConv2d, True),
(ConvertUpsampleBicubicWithBilinear, False),
(DecomposeAny, True),
(ExpandBroadcastTensorShape, False),
(FixedLinearKeepDim, True),
Expand Down
5 changes: 1 addition & 4 deletions backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def get_passes_dependency_for_capture_program():
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertUpsampleBicubicWithBilinear,
DecomposeAny,
DecomposeLinalgVectorNorm,
ExpandBroadcastTensorShape,
Expand All @@ -97,19 +96,17 @@ def get_passes_dependency_for_capture_program():
AnnotateQuantAttrs: [
RecomposePixelUnshuffle,
ConvertBmmToMatmul,
ConvertUpsampleBicubicWithBilinear,
RemoveRedundancy,
],
AnnotateStack: [RemoveRedundancy],
AnnotateUnbind: [RemoveRedundancy],
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
ConvertUpsampleBicubicWithBilinear: [RemoveRedundancy],
DecomposeAny: [RemoveRedundancy],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
ExpandBroadcastTensorShape: [FoldQDQ],
FixedLinearKeepDim: [FoldQDQ],
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
I64toI32: [ConvertUpsampleBicubicWithBilinear, RemoveRedundancy],
I64toI32: [RemoveRedundancy],
LayoutTransform: [
AnnotateQuantAttrs,
ConvertConv1dToConv2d,
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
op_relu,
op_repeat,
op_reshape,
op_resize,
op_rms_norm,
op_rsqrt,
op_scalar_tensor,
Expand Down Expand Up @@ -153,6 +154,7 @@
op_relu,
op_repeat,
op_reshape,
op_resize,
op_rms_norm,
op_rsqrt,
op_scalar_tensor,
Expand Down
84 changes: 84 additions & 0 deletions backends/qualcomm/builders/op_resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np
import torch

from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpResize, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Resize(NodeVisitor):
# Because QNN support ResizeBilinear and ResizeNearestNeighbor, only bicubic need to be handled in resize op
target = ["aten.upsample_bicubic2d.vec"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
align_corners = cast(bool, node.args[2])
transformation_mode = np.uint32(2) if align_corners else np.uint32(1)
# This builder supports only bicubic resize.
interpolation_mode = np.uint32(2)
cubic_coeff = np.float32(-0.75)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
resize_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpResize.op_name,
)
resize_op.AddInputTensors([input_tensor_wrapper])
resize_op.AddOutputTensors([output_tensor_wrapper])

resize_op.AddScalarParam(
OpResize.param_exclude_outside,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{QCOM_DATA: False},
)
resize_op.AddScalarParam(
OpResize.param_transformation_mode,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{QCOM_DATA: transformation_mode},
)

resize_op.AddScalarParam(
OpResize.param_interpolation_mode,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{QCOM_DATA: interpolation_mode},
)
resize_op.AddScalarParam(
OpResize.param_cubic_coeff,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{QCOM_DATA: cubic_coeff},
)

return resize_op
12 changes: 6 additions & 6 deletions backends/qualcomm/builders/op_upsample_bilinear2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,23 @@ def define_node(
nodes_to_wrappers,
)

reisze_bilinear_op = PyQnnWrapper.PyQnnOpWrapper(
resize_bilinear_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpResizeBilinear.op_name,
)
reisze_bilinear_op.AddInputTensors([input_tensor_wrapper])
reisze_bilinear_op.AddOutputTensors([output_tensor_wrapper])
resize_bilinear_op.AddInputTensors([input_tensor_wrapper])
resize_bilinear_op.AddOutputTensors([output_tensor_wrapper])

reisze_bilinear_op.AddScalarParam(
resize_bilinear_op.AddScalarParam(
OpResizeBilinear.param_align_corners,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{QCOM_DATA: node.args[2]},
)
reisze_bilinear_op.AddScalarParam(
resize_bilinear_op.AddScalarParam(
OpResizeBilinear.param_half_pixel_centers,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{QCOM_DATA: not node.args[2]},
)

return reisze_bilinear_op
return resize_bilinear_op
12 changes: 6 additions & 6 deletions backends/qualcomm/builders/op_upsample_nearest2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,23 @@ def define_node(
nodes_to_wrappers,
)

reisze_nearest_op = PyQnnWrapper.PyQnnOpWrapper(
resize_nearest_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpResizeNearestNeighbor.op_name,
)
reisze_nearest_op.AddInputTensors([input_tensor_wrapper])
reisze_nearest_op.AddOutputTensors([output_tensor_wrapper])
resize_nearest_op.AddInputTensors([input_tensor_wrapper])
resize_nearest_op.AddOutputTensors([output_tensor_wrapper])
# align_corners is guaranteed to be false
reisze_nearest_op.AddScalarParam(
resize_nearest_op.AddScalarParam(
OpResizeNearestNeighbor.param_align_corners,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{QCOM_DATA: False},
)
reisze_nearest_op.AddScalarParam(
resize_nearest_op.AddScalarParam(
OpResizeNearestNeighbor.param_half_pixel_centers,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{QCOM_DATA: True},
)

return reisze_nearest_op
return resize_nearest_op
10 changes: 10 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,16 @@ class OpReshape:
op_name: str = "Reshape"


@dataclass(init=False, frozen=True)
class OpResize:
op_name: str = "Resize"
param_exclude_outside: str = "exclude_outside"
param_transformation_mode: str = "transformation_mode"
param_interpolation_mode: str = "interpolation_mode"
param_nearest_mode: str = "nearest_mode"
param_cubic_coeff: str = "cubic_coeff"


@dataclass(init=False, frozen=True)
class OpResizeBilinear:
op_name: str = "ResizeBilinear"
Expand Down
1 change: 0 additions & 1 deletion backends/qualcomm/partition/common_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.slice_scatter.default,
exir_ops.edge.aten.copy.default,
exir_ops.edge.aten.upsample_bicubic2d.vec,
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
]

Expand Down
7 changes: 7 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,13 @@ def annotate_upsample_bilinear2d(
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.upsample_bicubic2d.vec])
def annotate_upsample_upsample_bicubic2d(
node: Node, quantization_config: QuantizationConfig
) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.upsample_nearest2d.vec])
def annotate_upsample_nearest2d(
node: Node, quantization_config: QuantizationConfig
Expand Down
17 changes: 17 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,23 @@ def forward(self, x):
return x6


class ResizeBicubic(torch.nn.Module):
def __init__(self, size, scale_factor, align_corners):
super().__init__()
self.align_corners = align_corners
self.scale_factor = scale_factor
self.size = size

def forward(self, x):
return torch.nn.functional.interpolate(
x,
size=self.size,
scale_factor=self.scale_factor,
mode="bicubic",
align_corners=self.align_corners,
)


class ResizeBilinear2D(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
Loading
Loading