Skip to content

Qualcomm AI Engine Direct - op support #8306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.logical_not.default,
exir_ops.edge.aten.lt.Scalar,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten._log_softmax.default,
Expand All @@ -88,6 +89,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.topk.default,
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.aten.where.self,
*q_ops,
*dq_ops,
_operator.getitem,
Expand Down
4 changes: 4 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
op_linear,
op_log,
op_log_softmax,
op_logical_not,
op_lt,
op_matmul,
op_max,
Expand Down Expand Up @@ -76,6 +77,7 @@
op_unsqueeze,
op_upsample_bilinear2d,
op_upsample_nearest2d,
op_where,
)

__all__ = [
Expand Down Expand Up @@ -113,6 +115,7 @@
op_le,
op_linear,
op_log,
op_logical_not,
op_log_softmax,
op_lt,
op_matmul,
Expand Down Expand Up @@ -150,4 +153,5 @@
op_unsqueeze,
op_upsample_bilinear2d,
op_upsample_nearest2d,
op_where,
]
55 changes: 55 additions & 0 deletions backends/qualcomm/builders/op_logical_not.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseNot, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Not(NodeVisitor):
target = ["aten.logical_not.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

logical_not_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseNot.op_name,
)
logical_not_op.AddInputTensors([input_tensor_wrapper])
logical_not_op.AddOutputTensors([output_tensor_wrapper])

return logical_not_op
81 changes: 81 additions & 0 deletions backends/qualcomm/builders/op_where.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseSelect, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Where(NodeVisitor):
target = ["aten.where.self"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
conditional_input_node = node.args[0]
conditional_input_tensor = self.get_tensor(conditional_input_node, node)
conditional_input_tensor_wrapper = self.define_tensor(
conditional_input_node,
node,
conditional_input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

true_input_node = node.args[1]
true_input_tensor = self.get_tensor(true_input_node, node)
true_input_tensor_wrapper = self.define_tensor(
true_input_node,
node,
true_input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

false_input_node = node.args[2]
false_input_tensor = self.get_tensor(false_input_node, node)
false_input_tensor_wrapper = self.define_tensor(
false_input_node,
node,
false_input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

where_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpElementWiseSelect.op_name,
)
where_op.AddInputTensors(
[
conditional_input_tensor_wrapper,
true_input_tensor_wrapper,
false_input_tensor_wrapper,
]
)
where_op.AddOutputTensors([output_tensor_wrapper])

return where_op
10 changes: 10 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ class OpElementWiseNeuron:
param_beta: str = "beta"


@dataclass(init=False, frozen=True)
class OpElementWiseNot:
op_name: str = "ElementWiseNot"


@dataclass(init=False, frozen=True)
class OpElementWisePower:
op_name: str = "ElementWisePower"
Expand All @@ -173,6 +178,11 @@ class OpElementWiseSin:
op_name: str = "ElementWiseSin"


@dataclass(init=False, frozen=True)
class OpElementWiseSelect:
op_name = "ElementWiseSelect"


@dataclass(init=False, frozen=True)
class OpElementWiseSubtract:
op_name = "ElementWiseSubtract"
Expand Down
3 changes: 0 additions & 3 deletions backends/qualcomm/partition/common_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from executorch.exir.dialects._ops import ops as exir_ops


not_supported_operator = [
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.full.default,
Expand All @@ -18,8 +17,6 @@

to_be_implemented_operator = [
exir_ops.edge.aten.any.dim,
exir_ops.edge.aten.logical_not.default,
exir_ops.edge.aten.where.self,
]

constant_operator = [
Expand Down
23 changes: 23 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,3 +1070,26 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
output_qspec=quantization_config.output_activation,
_annotated=True,
)


@register_annotator([torch.ops.aten.where.self])
def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
true_input_act = node.args[1]
false_input_act = node.args[2]
if _is_annotated([node]):
return

_annotate_input_qspec_map(
node,
true_input_act,
quantization_config.input_activation,
)

_annotate_input_qspec_map(
node,
false_input_act,
quantization_config.input_activation,
)

_annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated([node])
26 changes: 26 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,14 @@ def forward(self, x):
return torch.log(x)


class LogicalNot(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.logical_not(x > 0)


class LogSoftmax(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -1306,3 +1314,21 @@ def forward(self, x, y):
x = x.view(new_shape)
x = x.permute(0, 2, 1, 3)
return torch.matmul(x, y.transpose(-1, -2))


class Where(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y, z):
return torch.where(x >= torch.zeros(x.shape), y, z)


class WhereConstant(torch.nn.Module):
def __init__(self, pos, neg):
super().__init__()
self.register_buffer("pos", pos)
self.register_buffer("neg", neg)

def forward(self, x):
return torch.where(x >= torch.zeros(x.shape), self.pos, self.neg)
36 changes: 36 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ def test_qnn_backend_log(self):
sample_input = (torch.rand([1, 2, 3, 4]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_logical_not(self):
module = LogicalNot() # noqa: F405
sample_input = (torch.rand([1, 2, 3, 4]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_log_softmax(self):
module = LogSoftmax() # noqa: F405
sample_input = (torch.randn([1, 4, 8, 8]),)
Expand Down Expand Up @@ -692,6 +697,18 @@ def test_qnn_backend_view(self):
sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_where(self):
modules = [
Where(), # noqa: F405
WhereConstant(torch.randn(3, 2), torch.randn(3, 2)), # noqa: F405
]
sample_inputs = [
(torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
(torch.randn(3, 2),),
]
for i, module in enumerate(modules):
self.lower_module_and_test_output(module, sample_inputs[i])


class TestQNNFloatingPointModel(TestQNN):
# TODO: refactor to support different backends
Expand Down Expand Up @@ -1396,6 +1413,12 @@ def test_qnn_backend_log(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_logical_not(self):
module = LogicalNot() # noqa: F405
sample_input = (torch.rand([1, 2, 3, 4]),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_log_softmax(self):
module = LogSoftmax() # noqa: F405
sample_input = (torch.randn([1, 4, 8, 8]),)
Expand Down Expand Up @@ -1609,6 +1632,19 @@ def test_qnn_backend_view(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_where(self):
modules = [
Where(), # noqa: F405
WhereConstant(torch.randn(3, 2), torch.randn(3, 2)), # noqa: F405
]
sample_inputs = [
(torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
(torch.randn(3, 2),),
]
for i, module in enumerate(modules):
module = self.get_qdq_module(module, sample_inputs[i])
self.lower_module_and_test_output(module, sample_inputs[i])


class TestQNNQuantizedModel(TestQNN):
# TODO: refactor to support different backends
Expand Down
Loading