Qualcomm AI Engine Direct - Add index and index_put op #4481

Closed
4 changes: 4 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -23,6 +23,8 @@
op_hardsigmoid,
op_hardswish,
op_hardtanh,
op_index,
op_index_put,
op_layer_norm,
op_linear,
op_log_softmax,
@@ -75,6 +77,8 @@
op_hardswish,
op_hardtanh,
op_hardsigmoid,
op_index,
op_index_put,
op_layer_norm,
op_linear,
op_log_softmax,
83 changes: 83 additions & 0 deletions backends/qualcomm/builders/op_index.py
@@ -0,0 +1,83 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Index(NodeVisitor):
# schema = aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
target = ["aten.index.Tensor"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)

        if len(node.args[1]) > 1:
            # TODO: consider implementing this recursively to handle multiple index tensors.
            raise NotImplementedError("Tuples of index tensors are not supported.")

indices_node = node.args[1][0]
indices_tensor = self.get_tensor(indices_node, node).to(torch.int32)
        assert indices_tensor.size(0) != 0, "Empty indices lists are not supported"

indices_tensor_wrapper = self.define_tensor(
indices_node,
indices_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)

gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper]

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)
gather_output_tensors = [output_tensor_wrapper]

gather_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpGather.op_name,
)
gather_op.AddInputTensors(gather_input_tensors)
gather_op.AddOutputTensors(gather_output_tensors)

        # If tuples of index tensors become supported, derive the axis from their length.
gather_op.AddScalarParam(
OpGather.param_axis,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
{"data": np.int32(0)},
)

return gather_op
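
For reference, a minimal sketch (not part of this PR) of the single-index-tensor case this visitor handles: `aten.index.Tensor` with one index tensor selects slices along dim 0, which is why it lowers to QNN Gather with `axis=0`.

```python
# Minimal sketch (not part of this PR): aten.index.Tensor with a single
# index tensor behaves like a gather along dim 0.
import torch

x = torch.randn(8, 172, 64)  # same shape as the sample input in the tests
idx = torch.tensor([[0, 1], [2, 3], [4, 5]])

out = torch.ops.aten.index.Tensor(x, [idx])

# Equivalent: select rows by the flattened indices, then restore the index
# tensor's shape in front of the remaining dims.
ref = torch.index_select(x, 0, idx.flatten()).reshape(*idx.shape, *x.shape[1:])
assert torch.equal(out, ref)  # both are (3, 2, 172, 64)
```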
83 changes: 83 additions & 0 deletions backends/qualcomm/builders/op_index_put.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class IndexPutVisitor(NodeVisitor):
target = ["aten.index_put.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)
        indices_node = node.args[1]
        indices_list = [
            self.get_tensor(idx, idx) for idx in indices_node if idx is not None
        ]

# Unpack the tuple
indices_unpacked = [torch.flatten(idx) for idx in indices_list]

        # Convert to a 2-D tensor
        indices_qnn = torch.cat(indices_unpacked).unsqueeze(0)
        index_nodes = [n for n in indices_node if isinstance(n, torch.fx.Node)]
        # TODO: consider writing a pass that combines the indices into a single input tensor.
        assert len(index_nodes) == 1, "Multiple index tensors are not supported"

indices_tensor_wrapper = self.define_tensor(
            index_nodes[0],
indices_qnn,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)
value_node = node.args[2]

value_tensor = self.get_tensor(value_node, node)

value_tensor_wrapper = self.define_tensor(
value_node,
value_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)

index_put_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpScatterNd.op_name,
)
index_put_op.AddInputTensors(
[input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper]
)
index_put_op.AddOutputTensors([output_tensor_wrapper])

return index_put_op
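
For reference, a minimal sketch (not part of this PR) of the `index_put` pattern this visitor targets, mirroring the KV-cache update exercised by the tests: a leading `None` index skips dim 0, leaving a single index tensor for dim 1, which maps onto QNN ScatterNd.

```python
# Minimal sketch (not part of this PR): index_put_ with one index tensor,
# as used for KV-cache updates in the tests below.
import torch

k_cache = torch.zeros(1, 1024, 12, 64)
input_pos = torch.tensor([2])
k_val = torch.randn(1, 1, 12, 64)

# [None, input_pos] skips dim 0 and writes k_val into dim 1 at position 2;
# the visitor drops the None entry and lowers the remaining index tensor.
out = torch.ops.aten.index_put_(k_cache, [None, input_pos], k_val)
assert torch.equal(out[:, 2], k_val[:, 0])
```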
34 changes: 23 additions & 11 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
@@ -124,13 +124,6 @@ class OpExpandDims:
param_axis: str = "axis"


@dataclass(init=False, frozen=True)
class OpReduceSum:
op_name: str = "ReduceSum"
param_axes: str = "axes"
param_keep_dims: str = "keep_dims"


@dataclass(init=False, frozen=True)
class OpFullyConnected:
op_name: str = "FullyConnected"
@@ -144,13 +137,14 @@ class OpGather:


@dataclass(init=False, frozen=True)
class OpGelu:
op_name: str = "Gelu"
class OpGatherND:
op_name: str = "GatherNd"
param_batch_dims: str = "batch_dims"


@dataclass(init=False, frozen=True)
class OpSqrt:
op_name: str = "ElementWiseSquareRoot"
class OpGelu:
op_name: str = "Gelu"


@dataclass(init=False, frozen=True)
@@ -246,6 +240,13 @@ class OpReduceMean:
param_keep_dims: str = "keep_dims"


@dataclass(init=False, frozen=True)
class OpReduceSum:
op_name: str = "ReduceSum"
param_axes: str = "axes"
param_keep_dims: str = "keep_dims"


@dataclass(init=False, frozen=True)
class OpRelu:
op_name: str = "Relu"
@@ -277,6 +278,12 @@ class OpResizeNearestNeighbor:
param_half_pixel_centers: str = "half_pixel_centers"


@dataclass(init=False, frozen=True)
class OpScatterNd:
op_name: str = "ScatterNd"
param_reduction: str = "reduction"


@dataclass(init=False, frozen=True)
class OpSigmoid:
op_name: str = "Sigmoid"
@@ -307,6 +314,11 @@ class OpSplit:
param_split_index: str = "split_index"


@dataclass(init=False, frozen=True)
class OpSqrt:
op_name: str = "ElementWiseSquareRoot"


@dataclass(init=False, frozen=True)
class OpSqueeze:
op_name: str = "Squeeze"
3 changes: 1 addition & 2 deletions backends/qualcomm/partition/common_defs.py
@@ -13,8 +13,7 @@
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.slice_scatter.default,
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.index_put.default,
exir_ops.edge.aten.copy.default,
]

allow_list_operator = [
32 changes: 32 additions & 0 deletions backends/qualcomm/quantizer/utils.py
@@ -784,6 +784,38 @@ def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
)


@register_annotator([torch.ops.aten.index.Tensor])
def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_in_out_obs_sharing_op(node, quantization_config)
if not _is_annotated([node]):
input_qspec_map = {}
input = node.args[0]
input_qspec_map[input] = quantization_config.input_activation
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=SharedQuantizationSpec((input, node)),
_annotated=True,
)


@register_annotator(
[torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default]
)
def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
input = node.args[0]
value = node.args[2]

input_qspec_map = {}
input_qspec_map[input] = quantization_config.input_activation
input_qspec_map[value] = SharedQuantizationSpec((input, node))

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=SharedQuantizationSpec((input, node)),
_annotated=True,
)
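
A brief numeric sketch (not from the PR) of why `annotate_index_put` shares one quantization spec across input, value, and output: entries the op does not touch are copied verbatim from the cache to the output, so all three tensors must agree on scale and zero-point.

```python
# Numeric sketch (not from the PR): input, value, and output of index_put
# must share quantization parameters.
import torch

scale = 0.02
cache = torch.randn(4, 8)
q_cache = torch.clamp((cache / scale).round(), -128, 127)

val = torch.randn(1, 8)
q_val = torch.clamp((val / scale).round(), -128, 127)  # same scale as cache

# Integer-domain index_put: valid only because the qparams match; untouched
# rows of q_cache flow to the output unchanged.
q_out = q_cache.clone()
q_out[2] = q_val
```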


@register_annotator([torch.ops.aten.expand.default])
def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_in_out_obs_sharing_op(node, quantization_config)
23 changes: 23 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -443,6 +443,29 @@ def forward(self, x):
return self.hardtanh(x)


class Index(torch.nn.Module):
def __init__(self):
super().__init__()
self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]])
self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]])

def forward(self, x):
return x[self.idx0] + x[self.idx1]


class IndexPut(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"k_cache",
torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
)

def forward(self, input_pos, k_val):
k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)
return k_out


class LayerNorm(torch.nn.Module):
def __init__(self):
super().__init__()
28 changes: 28 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -256,6 +256,19 @@ def test_qnn_backend_hardtanh(self):
sample_input = (torch.randn([2, 5, 1, 3]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_index(self):
module = Index() # noqa: F405
sample_input = (torch.randn([8, 172, 64]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_index_put(self):
module = IndexPut() # noqa: F405
sample_input = (
torch.tensor([2], dtype=torch.int32),
torch.randn([1, 1, 12, 64]),
)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_interpolate_bilinear_2d(self):
module = ResizeBilinear2D() # noqa: F405
sample_input = (torch.randn(2, 3, 4, 5),)
@@ -827,6 +840,21 @@ def test_qnn_backend_hardtanh(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_index(self):
module = Index() # noqa: F405
sample_input = (torch.randn([8, 172, 64]),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_index_put(self):
module = IndexPut() # noqa: F405
sample_input = (
torch.tensor([2], dtype=torch.int32),
torch.randn([1, 1, 12, 64]),
)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_interpolate_bilinear_2d(self):
module = ResizeBilinear2D() # noqa: F405
sample_input = (torch.randn(2, 3, 4, 5),)