[Draft] Qualcomm AI Engine Direct - [WIP] llama2... #3656

Closed · wants to merge 1 commit
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/__init__.py
@@ -10,7 +10,6 @@
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
-    op_cast,
     op_cat,
     op_ceil,
     op_clamp,
@@ -41,11 +40,13 @@
     op_skip_ops,
     op_slice_copy,
     op_softmax,
+    op_split,
     op_sqrt,
     op_squeeze,
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_to,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
@@ -57,7 +58,6 @@
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
-    op_cast,
     op_cat,
     op_ceil,
     op_clamp,
@@ -87,11 +87,13 @@
     op_skip_ops,
     op_slice_copy,
     op_softmax,
+    op_split,
     op_squeeze,
     op_sqrt,
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_to,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
40 changes: 17 additions & 23 deletions backends/qualcomm/builders/node_visitor.py
@@ -14,7 +14,13 @@

 from executorch.exir.dialects._ops import ops as exir_ops

-from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
+from .utils import (
+    deduce_dtype,
+    get_parameter,
+    is_graph_input,
+    is_graph_output,
+    is_parameter,
+)


 QNN_QUANT_TYPE_MAP = {
@@ -215,24 +221,9 @@ def get_data_type(
         self,
         tensor: torch.Tensor,
         quant_config: Dict,
-        is_tensor: bool,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
-        if quant_config and is_tensor:
-            quant_range = quant_config["quant_max"] - quant_config["quant_min"]
-            unsigned = quant_config["quant_min"] >= 0
-            if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
-                if unsigned:
-                    quant_config["dtype"] = torch.uint8
-                else:
-                    quant_config["dtype"] = torch.int8
-            elif (
-                quant_range
-                <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
-            ):
-                if unsigned:
-                    quant_config["dtype"] = torch.uint16
-                else:
-                    quant_config["dtype"] = torch.int16
+        if quant_config:
+            quant_config["dtype"] = deduce_dtype(tensor, quant_config)
             return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
         else:
             return QNN_TENSOR_TYPE_MAP[tensor.dtype]
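For reference, the dtype-selection rules deleted above now live in a shared deduce_dtype helper imported from .utils. A minimal sketch of that helper, reconstructed from the removed inline logic; the final fallback is an assumption and the real implementation in utils.py may differ:

# Hypothetical reconstruction of .utils.deduce_dtype; mirrors the deleted
# inline range/sign rules from get_data_type.
import torch

def deduce_dtype(tensor: torch.Tensor, quant_config: dict) -> torch.dtype:
    quant_range = quant_config["quant_max"] - quant_config["quant_min"]
    unsigned = quant_config["quant_min"] >= 0
    if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
        return torch.uint8 if unsigned else torch.int8
    if quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
        return torch.uint16 if unsigned else torch.int16
    return tensor.dtype  # assumption: keep the tensor's dtype for wider ranges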
@@ -277,7 +268,7 @@ def define_tensor(
         nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
         is_input_tensor: bool,
         node_name: str = None,
-        is_tensor: bool = True,
+        wrapper_idx: int = 0,
     ) -> PyQnnWrapper.TensorWrapper:
         """
         Convert torch.Tensor to TensorWrapper
@@ -293,17 +284,20 @@
         if node_name is None:
             node_name = node.name

-        if node_name in nodes_to_wrappers:
-            return nodes_to_wrappers[node_name]
+        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
+            return cached

         tensor_name = node.name
         if is_graph_input(node, self.edge_program):
             tensor_name = "QnnInput_" + str(self.external_ids[node]) + "_" + tensor_name
         if is_graph_output(node):
             tensor_name = "output_" + tensor_name
         dims = [1] if len(tensor.size()) == 0 else tensor.size()
         tensor_type = self.get_tensor_type(node, tensor_type)
         quant_encoding, quant_configs = self.get_quant_encoding_conf(
             node, is_input_tensor
         )
-        dtype = self.get_data_type(tensor, quant_configs, is_tensor)
+        dtype = self.get_data_type(tensor, quant_configs)
         if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
             tensor_wrapper = PyQnnWrapper.TensorWrapper(
                 tensor_name,
@@ -334,7 +328,7 @@ def define_tensor(
                 tensor.detach().numpy(),
                 True,
             )
-        nodes_to_wrappers[node_name] = tensor_wrapper
+        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
         return tensor_wrapper

     def define_node(
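Note the cache shape change: nodes_to_wrappers now maps node name to {wrapper_idx: TensorWrapper}, so multi-output ops such as the new split visitor can register one wrapper per output. For the nodes_to_wrappers[node_name][wrapper_idx] write above to work on first touch, the caller presumably initializes the cache as a nested defaultdict; a minimal sketch of that assumption:

# Sketch only, assuming the conversion driver owns the cache setup.
from collections import defaultdict

nodes_to_wrappers = defaultdict(dict)  # node_name -> {wrapper_idx: wrapper}
nodes_to_wrappers["split"][1] = "wrapper_for_output_1"  # no KeyError on first write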
57 changes: 0 additions & 57 deletions backends/qualcomm/builders/op_cast.py

This file was deleted.

2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_embedding.py
@@ -34,7 +34,7 @@ def define_node(
             weight_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
             nodes_to_wrappers,
-            is_input_tensor=False,
+            is_input_tensor=True,
         )

         indices_node = node.args[1]
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/op_skip_ops.py
@@ -46,5 +46,7 @@ def define_node(
             raise AssertionError(
                 f"Invalid number of index for {node.name}: {len(node.args[1])}"
             )
-        nodes_to_wrappers[node.name] = nodes_to_wrappers.get(node.args[0].name)
+        nodes_to_wrappers[node.name] = {
+            0: nodes_to_wrappers.get(node.args[0].name).get(node.args[1])
+        }
         return
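With the two-level cache, a getitem node no longer aliases the producer's whole entry; it selects the producer's idx-th wrapper and republishes it under index 0. A toy illustration of the forwarding above (names and string stand-ins are hypothetical):

# Strings stand in for TensorWrapper objects.
nodes_to_wrappers = {"split": {0: "w0", 1: "w1", 2: "w2"}}
# getitem(split, 1) re-registers w1 as its own single output:
nodes_to_wrappers["getitem_1"] = {0: nodes_to_wrappers.get("split").get(1)}
assert nodes_to_wrappers["getitem_1"][0] == "w1"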
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/op_slice_copy.py
@@ -61,7 +61,9 @@ def define_node(
         ranges = []
         for i in range(input_tensor_rank):
             if i == dim:
-                ranges.extend([start, end, 1])
+                # find step
+                step = node.args[4] if len(node.args) > 4 else 1
+                ranges.extend([start, end, step])
             else:
                 ranges.extend([0, input_tensor.shape[i], 1])

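The new step handling covers strided slices, which the old code hard-coded to stride 1. A quick sanity check of the semantics being lowered, assuming the usual aten.slice_copy positional order (input, dim, start, end, step), which puts step at node.args[4]:

import torch

x = torch.arange(12).reshape(2, 6)
y = x[:, 0:6:2]            # traces to a slice with step=2 on dim 1
assert y.shape == (2, 3)   # ranges for that dim become [0, 6, 2]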
85 changes: 85 additions & 0 deletions backends/qualcomm/builders/op_split.py
@@ -0,0 +1,85 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Split(NodeVisitor):
    target = ["aten.split_with_sizes.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        split_input_tensors = [input_tensor_wrapper]

        axis = 0 if len(node.args) < 3 else cast(int, node.args[2])
        if axis < 0:
            axis = axis % len(input_tensor.shape)
        if "axis_order" in node.meta:
            axis = node.meta["axis_order"].index(axis)

        # this is not the general case, only a quick workaround here
        index = np.arange(1, input_tensor.shape[axis], dtype=np.uint32)
        index_shape = [len(index)]

        split_output_tensors = []
        for i in range(input_tensor.shape[axis]):
            output_tensor = self.get_tensor(node, node, i)
            output_tensor_wrapper = self.define_tensor(
                node,
                output_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                nodes_to_wrappers,
                is_input_tensor=False,
                wrapper_idx=i,
            )
            split_output_tensors.append(output_tensor_wrapper)

        split_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpSplit.op_name,
        )
        split_op.AddInputTensors(split_input_tensors)
        split_op.AddOutputTensors(split_output_tensors)

        split_op.AddScalarParam(
            OpSplit.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {"data": np.uint32(axis)},
        )
        split_op.AddTensorParam(
            OpSplit.param_split_index,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(index_shape),
            index_shape,
            index,
            True,
        )

        return split_op
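The np.arange(1, ...) workaround splits the tensor at every index along axis, i.e. into size-1 chunks, regardless of the sizes passed to split_with_sizes. A sketch of the general case the in-code comment alludes to, assuming QNN's split_index parameter lists the start offsets of the second and later chunks (an assumption about the QNN op semantics; this is not part of the PR):

import numpy as np

sizes = [2, 3, 1]  # e.g. node.args[1] for torch.split_with_sizes(x, [2, 3, 1], dim=axis)
index = np.cumsum(sizes[:-1], dtype=np.uint32)  # -> [2, 5]
index_shape = [len(index)]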