Qualcomm AI Engine Direct - fix conv2d to meet QNN constraint
Differential Revision: D60967580

Pull Request resolved: #4560
haowhsu-quic authored Aug 12, 2024
1 parent 99e1ae1 commit e800626
Showing 6 changed files with 129 additions and 63 deletions.
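
At a high level, the HTP backend requires a convolution to carry a bias input, so the conv builder now synthesizes a zero-filled bias tensor (with int32 quantization attributes) whenever the model omits one. A minimal sketch of the newly supported path, assuming the test models in backends/qualcomm/tests/models.py are importable under the executorch package as shown:

    import torch
    from executorch.backends.qualcomm.tests.models import Conv2dSequential

    # Bias-free convolutions previously violated the QNN bias constraint;
    # the builder now injects a dummy zero bias during lowering.
    module = Conv2dSequential(bias=False)
    sample_input = (torch.randn([1, 1, 3, 3]),)
    output = module(sample_input[0])  # eager sanity check before lowering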
33 changes: 20 additions & 13 deletions backends/qualcomm/builders/node_visitor.py
@@ -12,13 +12,20 @@
 import numpy as np
 import torch
 from executorch.backends.qualcomm.utils.constants import (
+    QCOM_AXIS,
     QCOM_AXIS_ORDER,
     QCOM_BITWIDTH,
+    QCOM_DTYPE,
     QCOM_ENCODING,
+    QCOM_OFFSET,
     QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
     QCOM_REQUANTIZE,
+    QCOM_SCALE,
     QCOM_SCALE_OFFSET,
     QCOM_SCALES,
+    QCOM_ZERO_POINT,
     QCOM_ZERO_POINTS,
 )

@@ -125,16 +132,16 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
             "convolution" in user_0.target.__name__
             and list(node.users)[0].args[1] == node
         ):
-            quant_config["axis"] = 3
+            quant_config[QCOM_AXIS] = 3

         else:
-            quant_config["axis"] = quant_attrs["axis"]
+            quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

         quant_config[QCOM_SCALE_OFFSET] = scale_offset
         # special case for 4 bits
         if (
-            quant_config["dtype"] == torch.int8
-            and quant_config["quant_max"] - quant_config["quant_min"] <= 15
+            quant_config[QCOM_DTYPE] == torch.int8
+            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
         ):
             quant_config[QCOM_BITWIDTH] = 4
         return (
@@ -149,11 +156,11 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
     def make_qnn_per_tensor_config(self, quant_attrs: Dict):
         quant_config = copy.deepcopy(quant_attrs)
         # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
-        quant_config["offset"] = -quant_attrs["zero_point"]
+        quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
         # special case for 4 bits
         if (
-            quant_config["dtype"] == torch.int8
-            and quant_config["quant_max"] - quant_config["quant_min"] <= 15
+            quant_config[QCOM_DTYPE] == torch.int8
+            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
         ):
             quant_config[QCOM_BITWIDTH] = 4
         return (
@@ -187,15 +194,15 @@ def get_quant_tensor_value(
         self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
     ) -> torch.Tensor:
         if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
-            scale = quant_attrs["scale"]
-            zero_point = quant_attrs["zero_point"]
+            scale = quant_attrs[QCOM_SCALE]
+            zero_point = quant_attrs[QCOM_ZERO_POINT]
         else:  # per channel case
             scale = quant_attrs[QCOM_SCALES]
             zero_point = quant_attrs[QCOM_ZERO_POINTS]

-        dtype = quant_configs["dtype"]
+        dtype = quant_configs[QCOM_DTYPE]

-        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
+        tensor = tensor.div(scale + 1e-6).add(zero_point).round().to(dtype)
         # Make the backends access data correctly
         if quant_configs.get(QCOM_BITWIDTH) == 4:
             mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
@@ -233,8 +240,8 @@ def get_data_type(
         quant_config: Dict,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
         if quant_config:
-            quant_config["dtype"] = deduce_dtype(tensor, quant_config)
-            return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
+            quant_config[QCOM_DTYPE] = deduce_dtype(tensor, quant_config)
+            return QNN_QUANT_TYPE_MAP[quant_config[QCOM_DTYPE]]

         return QNN_TENSOR_TYPE_MAP[tensor.dtype]

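Two details in node_visitor.py are easy to miss. The 4-bit special case fires when quant_max - quant_min <= 15, which is exactly the span of a 4-bit range (int4 covers [-8, 7], uint4 covers [0, 15]) packed into an int8 container. The new 1e-6 epsilon in get_quant_tensor_value guards the division now that op_conv2d.py (below) can attach a bias whose scale is set to 0. A standalone sketch of the arithmetic, independent of the QNN wrappers:

    import torch

    # With scale = 0, tensor.div(scale) would produce inf/nan; the epsilon
    # keeps the all-zero dummy bias quantizing cleanly to zero.
    bias = torch.zeros(3)
    scale, zero_point = 0.0, 0
    quantized = bias.div(scale + 1e-6).add(zero_point).round().to(torch.int32)
    print(quantized)  # tensor([0, 0, 0], dtype=torch.int32)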
86 changes: 62 additions & 24 deletions backends/qualcomm/builders/op_conv2d.py
@@ -10,7 +10,16 @@

 import numpy as np
 import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DATA,
+    QCOM_DTYPE,
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
+)
+from executorch.exir.dialects._ops import ops as exir_ops

 from .node_visitor import NodeVisitor, register_node_visitor
 from .qnn_constants import (
@@ -85,6 +94,52 @@ def _add_conv_op_parameter(

         return conv_op

+    def _get_bias_tensor(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
+        num_output_channel: int,
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        # build dummy node if bias is not given
+        bias_node = (
+            node.args[2]
+            if node.args[2] is not None
+            else torch.fx.Node(
+                node.graph,
+                node.name + "_runtime_bias",
+                "call_function",
+                exir_ops.edge.aten.full.default,
+                (),  # args
+                {},  # kwargs
+            )
+        )
+        # zeros tensor to meet HTP constraint if bias is not given
+        bias_tensor = (
+            get_parameter(bias_node, self.edge_program)
+            if node.args[2] is not None
+            else torch.zeros(num_output_channel)
+        )
+        # insert quant attribute to meet HTP constraint if bias is not given
+        if (
+            node.args[2] is None
+            and (bias_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS)) is not None
+        ):
+            quant_attrs = bias_quant_attrs.copy()
+            quant_attrs[QCOM_ZERO_POINT] = 0
+            quant_attrs[QCOM_SCALE] = 0
+            quant_attrs[QCOM_DTYPE] = torch.int32
+            quant_attrs[QCOM_QUANT_MAX] = torch.iinfo(torch.int32).max
+            quant_attrs[QCOM_QUANT_MIN] = torch.iinfo(torch.int32).min + 1
+            bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+
+        return self.define_tensor(
+            bias_node,
+            bias_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+
     def _define_conv1d(
         self,
         node: torch.fx.Node,
@@ -149,17 +204,9 @@ def _define_conv1d(
             is_input_tensor=False,
         )
         conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper]
-        if node.args[2] is not None:
-            bias_node = node.args[2]
-            bias_tensor = get_parameter(bias_node, self.edge_program)
-            bias_tensor_wrapper = self.define_tensor(
-                bias_node,
-                bias_tensor,
-                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-                nodes_to_wrappers,
-                is_input_tensor=False,
-            )
-            conv_input_tensors.append(bias_tensor_wrapper)
+        conv_input_tensors.append(
+            self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1])
+        )

         stride = [1] + cast(List[int], node.args[3])
         padding = [0] + cast(List[int], node.args[4])
@@ -265,18 +312,9 @@ def define_node(
             is_input_tensor=False,
         )
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
-
-        if node.args[2] is not None:
-            bias_node = node.args[2]
-            bias_tensor = get_parameter(bias_node, self.edge_program)
-            bias_tensor_wrapper = self.define_tensor(
-                bias_node,
-                bias_tensor,
-                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-                nodes_to_wrappers,
-                is_input_tensor=False,
-            )
-            conv_input_tensors.append(bias_tensor_wrapper)
+        conv_input_tensors.append(
+            self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1])
+        )

         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
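For reference, the attributes stamped onto the dummy bias follow a symmetric int32 convention: zero point 0, scale 0 (epsilon-guarded later in get_quant_tensor_value), and the full int32 range with the minimum reserved. A standalone sketch of the dict _get_bias_tensor builds, spelling out the string values behind the QCOM_* keys:

    import torch

    # Sketch only -- mirrors the quant attributes attached to the dummy
    # bias node so HTP accepts a convolution without bias.
    bias_quant_attrs = {
        "zero_point": 0,                                # QCOM_ZERO_POINT
        "scale": 0,                                     # QCOM_SCALE
        "dtype": torch.int32,                           # QCOM_DTYPE
        "quant_max": torch.iinfo(torch.int32).max,      # QCOM_QUANT_MAX
        "quant_min": torch.iinfo(torch.int32).min + 1,  # QCOM_QUANT_MIN
    }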
10 changes: 7 additions & 3 deletions backends/qualcomm/builders/op_prelu.py
@@ -11,6 +11,10 @@
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_AXIS_ORDER,
     QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -77,10 +81,10 @@ def define_node(
         )
         if pow_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
             quant_attrs = pow_quant_attrs.copy()
-            quant_range = quant_attrs["quant_max"] - quant_attrs["quant_min"]
+            quant_range = quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
             # coeff is guaranteed to be positive
-            quant_attrs["zero_point"] = 0
-            quant_attrs["scale"] = coeff / quant_range
+            quant_attrs[QCOM_ZERO_POINT] = 0
+            quant_attrs[QCOM_SCALE] = coeff / quant_range
             scalar_node.meta[QCOM_QUANT_ATTRS] = quant_attrs

         scalar_tensor_wrapper = self.define_tensor(
16 changes: 8 additions & 8 deletions backends/qualcomm/tests/models.py
@@ -203,22 +203,22 @@ def example_inputs(self):


 class Conv1dSequential(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, bias=True):
         super().__init__()
         self.first = torch.nn.Conv1d(
             in_channels=1,
             out_channels=3,
             kernel_size=(3),
             padding=1,
-            bias=True,
+            bias=bias,
         )

         self.second = torch.nn.Conv1d(
             in_channels=3,
             out_channels=2,
             kernel_size=(3),
             padding=1,
-            bias=True,
+            bias=bias,
         )

     def forward(self, x):
@@ -315,36 +315,36 @@ def forward(self, x):


 class Conv2dSequential(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, bias=True):
         super().__init__()
         self.first = torch.nn.Conv2d(
             in_channels=1,
             out_channels=3,
             kernel_size=(3, 3),
             padding=1,
-            bias=True,
+            bias=bias,
         )
         self.second = torch.nn.Conv2d(
             in_channels=3,
             out_channels=2,
             kernel_size=(3, 3),
             padding=1,
-            bias=True,
+            bias=bias,
         )

     def forward(self, x):
         return self.second(self.first(x))


 class Conv2dSingle(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, bias=True):
         super().__init__()
         self.conv = torch.nn.Conv2d(
             in_channels=1,
             out_channels=3,
             kernel_size=(3, 3),
             padding=1,
-            bias=True,
+            bias=bias,
         )

     def forward(self, x):
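One pre-existing quirk left untouched here: kernel_size=(3) in the Conv1d modules is simply the integer 3, since parentheses alone do not form a tuple in Python, and torch.nn.Conv1d accepts either an int or a 1-tuple.

    assert (3) == 3 and (3,) != 3  # (3) is an int; only (3,) is a tuple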
40 changes: 25 additions & 15 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -109,14 +109,18 @@ def test_qnn_backend_clamp(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv1d(self):
-        module = Conv1dSequential()  # noqa: F405
+        modules = [Conv1dSequential(), Conv1dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3]),)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv2d(self):
-        module = Conv2dSequential()  # noqa: F405
+        modules = [Conv2dSequential(), Conv2dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_element_wise_add(self):
         test_comb = [
@@ -597,12 +601,14 @@ def setUp(self):
         )

     def test_qnn_backend_16a4w_conv2d(self):
-        module = Conv2dSingle()  # noqa: F405
+        modules = [Conv2dSingle(), Conv2dSingle(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
-        module = self.get_qdq_module(
-            module, sample_input, quant_dtype=QuantDtype.use_16a4w
-        )
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    module, sample_input, quant_dtype=QuantDtype.use_16a4w
+                )
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_16a4w_linear(self):
         module = Linear()  # noqa: F405
@@ -683,16 +689,20 @@ def test_qnn_backend_clamp(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv1d(self):
-        module = Conv1dSequential()  # noqa: F405
+        modules = [Conv1dSequential(), Conv1dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3]),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv2d(self):
-        module = Conv2dSequential()  # noqa: F405
+        modules = [Conv2dSequential(), Conv2dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_element_wise_add(self):
         test_comb = [
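The updated tests iterate both bias variants under unittest's subTest, so each variant is reported independently and a failure in one cannot mask the other. A generic illustration of the pattern:

    import unittest

    class SubTestDemo(unittest.TestCase):
        def test_variants(self):
            # Each subTest failure is reported separately without
            # aborting the surrounding loop.
            for i, bias in enumerate([True, False]):
                with self.subTest(i=i):
                    self.assertIn(bias, (True, False))

    if __name__ == "__main__":
        unittest.main()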
7 changes: 7 additions & 0 deletions backends/qualcomm/utils/constants.py
@@ -7,16 +7,23 @@
 # Qualcomm specific key

 # constants in backends/qualcomm/passes & backends/qualcomm/builders
+QCOM_AXIS = "axis"
 QCOM_AXIS_ORDER = "axis_order"
 QCOM_BITWIDTH = "bitwidth"
 QCOM_DATA = "data"
+QCOM_DTYPE = "dtype"
 QCOM_ENCODING = "encoding"
 QCOM_INSERTED_PERMUTE = "qnn_permute"
+QCOM_OFFSET = "offset"
 QCOM_QUANTIZED_IO = "q_tensor_io"
 QCOM_QUANT_ATTRS = "quant_attrs"
+QCOM_QUANT_MIN = "quant_min"
+QCOM_QUANT_MAX = "quant_max"
 QCOM_REQUANTIZE = "requantize"
+QCOM_SCALE = "scale"
 QCOM_SCALES = "scales"
 QCOM_SCALE_OFFSET = "scale_offset"
+QCOM_ZERO_POINT = "zero_point"
 QCOM_ZERO_POINTS = "zero_points"

# constants in backends/qualcomm/tests
