Commit e800626
Qualcomm AI Engine Direct - fix conv2d to meet QNN constraint
Differential Revision: D60967580
Pull Request resolved: #4560
1 parent 99e1ae1 commit e800626

6 files changed: +129 −63 lines

backends/qualcomm/builders/node_visitor.py

Lines changed: 20 additions & 13 deletions

@@ -12,13 +12,20 @@
 import numpy as np
 import torch
 from executorch.backends.qualcomm.utils.constants import (
+    QCOM_AXIS,
     QCOM_AXIS_ORDER,
     QCOM_BITWIDTH,
+    QCOM_DTYPE,
     QCOM_ENCODING,
+    QCOM_OFFSET,
     QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
     QCOM_REQUANTIZE,
+    QCOM_SCALE,
     QCOM_SCALE_OFFSET,
     QCOM_SCALES,
+    QCOM_ZERO_POINT,
     QCOM_ZERO_POINTS,
 )

@@ -125,16 +132,16 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
             "convolution" in user_0.target.__name__
             and list(node.users)[0].args[1] == node
         ):
-            quant_config["axis"] = 3
+            quant_config[QCOM_AXIS] = 3

         else:
-            quant_config["axis"] = quant_attrs["axis"]
+            quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

         quant_config[QCOM_SCALE_OFFSET] = scale_offset
         # special case for 4 bits
         if (
-            quant_config["dtype"] == torch.int8
-            and quant_config["quant_max"] - quant_config["quant_min"] <= 15
+            quant_config[QCOM_DTYPE] == torch.int8
+            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
         ):
             quant_config[QCOM_BITWIDTH] = 4
         return (

@@ -149,11 +156,11 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
     def make_qnn_per_tensor_config(self, quant_attrs: Dict):
         quant_config = copy.deepcopy(quant_attrs)
         # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
-        quant_config["offset"] = -quant_attrs["zero_point"]
+        quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
         # special case for 4 bits
         if (
-            quant_config["dtype"] == torch.int8
-            and quant_config["quant_max"] - quant_config["quant_min"] <= 15
+            quant_config[QCOM_DTYPE] == torch.int8
+            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
         ):
             quant_config[QCOM_BITWIDTH] = 4
         return (
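The "special case for 4 bits" branch fires whenever an int8-typed config spans at most 16 quantization levels. A minimal sketch of the detection, assuming a symmetric 4-bit weight range of [-7, 7] such as the 16a4w tests below produce:

import torch

quant_min, quant_max = -7, 7   # assumed 4-bit symmetric bounds, for illustration
dtype = torch.int8             # 4-bit values still travel in an int8 container
if dtype == torch.int8 and quant_max - quant_min <= 15:
    bitwidth = 4               # the builder sets quant_config[QCOM_BITWIDTH] = 4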
@@ -187,15 +194,15 @@ def get_quant_tensor_value(
         self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
     ) -> torch.Tensor:
         if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
-            scale = quant_attrs["scale"]
-            zero_point = quant_attrs["zero_point"]
+            scale = quant_attrs[QCOM_SCALE]
+            zero_point = quant_attrs[QCOM_ZERO_POINT]
         else:  # per channel case
             scale = quant_attrs[QCOM_SCALES]
             zero_point = quant_attrs[QCOM_ZERO_POINTS]

-        dtype = quant_configs["dtype"]
+        dtype = quant_configs[QCOM_DTYPE]

-        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
+        tensor = tensor.div(scale + 1e-6).add(zero_point).round().to(dtype)
         # Make the backends access data correctly
         if quant_configs.get(QCOM_BITWIDTH) == 4:
             mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)

@@ -233,8 +240,8 @@ def get_data_type(
         quant_config: Dict,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
         if quant_config:
-            quant_config["dtype"] = deduce_dtype(tensor, quant_config)
-            return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
+            quant_config[QCOM_DTYPE] = deduce_dtype(tensor, quant_config)
+            return QNN_QUANT_TYPE_MAP[quant_config[QCOM_DTYPE]]

         return QNN_TENSOR_TYPE_MAP[tensor.dtype]
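The added 1e-6 epsilon is not cosmetic: the synthetic conv bias introduced in op_conv2d.py below carries scale = 0, which would turn tensor.div(scale) into inf/nan. A hedged sketch of the guarded quantization step (the helper name is illustrative, not from the codebase):

import torch

def quantize_static_tensor(tensor, scale, zero_point, dtype):
    # with scale == 0 the epsilon keeps the division finite, and a
    # zeros tensor still rounds to exactly zero_point
    return tensor.div(scale + 1e-6).add(zero_point).round().to(dtype)

bias = torch.zeros(3)  # the dummy bias built by _get_bias_tensor below
q = quantize_static_tensor(bias, scale=0.0, zero_point=0, dtype=torch.int32)
assert torch.equal(q, torch.zeros(3, dtype=torch.int32))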

backends/qualcomm/builders/op_conv2d.py

Lines changed: 62 additions & 24 deletions

@@ -10,7 +10,16 @@

 import numpy as np
 import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DATA,
+    QCOM_DTYPE,
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
+)
+from executorch.exir.dialects._ops import ops as exir_ops

 from .node_visitor import NodeVisitor, register_node_visitor
 from .qnn_constants import (

@@ -85,6 +94,52 @@ def _add_conv_op_parameter(

         return conv_op

+    def _get_bias_tensor(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
+        num_output_channel: int,
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        # build dummy node if bias is not given
+        bias_node = (
+            node.args[2]
+            if node.args[2] is not None
+            else torch.fx.Node(
+                node.graph,
+                node.name + "_runtime_bias",
+                "call_function",
+                exir_ops.edge.aten.full.default,
+                (),  # args
+                {},  # kwargs
+            )
+        )
+        # zeros tensor to meet HTP constraint if bias is not given
+        bias_tensor = (
+            get_parameter(bias_node, self.edge_program)
+            if node.args[2] is not None
+            else torch.zeros(num_output_channel)
+        )
+        # insert quant attribute to meet HTP constraint if bias is not given
+        if (
+            node.args[2] is None
+            and (bias_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS)) is not None
+        ):
+            quant_attrs = bias_quant_attrs.copy()
+            quant_attrs[QCOM_ZERO_POINT] = 0
+            quant_attrs[QCOM_SCALE] = 0
+            quant_attrs[QCOM_DTYPE] = torch.int32
+            quant_attrs[QCOM_QUANT_MAX] = torch.iinfo(torch.int32).max
+            quant_attrs[QCOM_QUANT_MIN] = torch.iinfo(torch.int32).min + 1
+            bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+
+        return self.define_tensor(
+            bias_node,
+            bias_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+
     def _define_conv1d(
         self,
         node: torch.fx.Node,

@@ -149,17 +204,9 @@ def _define_conv1d(
             is_input_tensor=False,
         )
         conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper]
-        if node.args[2] is not None:
-            bias_node = node.args[2]
-            bias_tensor = get_parameter(bias_node, self.edge_program)
-            bias_tensor_wrapper = self.define_tensor(
-                bias_node,
-                bias_tensor,
-                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-                nodes_to_wrappers,
-                is_input_tensor=False,
-            )
-            conv_input_tensors.append(bias_tensor_wrapper)
+        conv_input_tensors.append(
+            self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1])
+        )

         stride = [1] + cast(List[int], node.args[3])
         padding = [0] + cast(List[int], node.args[4])

@@ -265,18 +312,9 @@ def define_node(
             is_input_tensor=False,
         )
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
-
-        if node.args[2] is not None:
-            bias_node = node.args[2]
-            bias_tensor = get_parameter(bias_node, self.edge_program)
-            bias_tensor_wrapper = self.define_tensor(
-                bias_node,
-                bias_tensor,
-                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-                nodes_to_wrappers,
-                is_input_tensor=False,
-            )
-            conv_input_tensors.append(bias_tensor_wrapper)
+        conv_input_tensors.append(
+            self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1])
+        )

         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
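What the synthetic bias ends up carrying can be sketched standalone (values per the _get_bias_tensor hunk above; the plain dict literal is illustrative, the real attrs live in node.meta):

import torch

num_output_channel = 3
bias_tensor = torch.zeros(num_output_channel)  # zeros, since HTP expects a bias input

quant_attrs = {
    "zero_point": 0,
    "scale": 0,  # relies on the +1e-6 guard added in node_visitor.py above
    "dtype": torch.int32,
    "quant_max": torch.iinfo(torch.int32).max,
    # min is offset by one, presumably to keep the range symmetric:
    # -(INT32_MIN + 1) == INT32_MAX
    "quant_min": torch.iinfo(torch.int32).min + 1,
}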

backends/qualcomm/builders/op_prelu.py

Lines changed: 7 additions & 3 deletions

@@ -11,6 +11,10 @@
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_AXIS_ORDER,
     QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
+    QCOM_SCALE,
+    QCOM_ZERO_POINT,
 )
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -77,10 +81,10 @@ def define_node(
         )
         if pow_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
             quant_attrs = pow_quant_attrs.copy()
-            quant_range = quant_attrs["quant_max"] - quant_attrs["quant_min"]
+            quant_range = quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
             # coeff is guaranteed to be positive
-            quant_attrs["zero_point"] = 0
-            quant_attrs["scale"] = coeff / quant_range
+            quant_attrs[QCOM_ZERO_POINT] = 0
+            quant_attrs[QCOM_SCALE] = coeff / quant_range
             scalar_node.meta[QCOM_QUANT_ATTRS] = quant_attrs

         scalar_tensor_wrapper = self.define_tensor(
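Choosing scale = coeff / quant_range pins the single PReLU coefficient to the top of the quantized range, so it round-trips with essentially no error. A quick check under assumed unsigned 8-bit bounds:

import math

coeff = 0.25                   # PReLU slope; guaranteed positive per the comment above
quant_min, quant_max = 0, 255  # assumed uint8 bounds, for illustration
scale = coeff / (quant_max - quant_min)
zero_point = 0

q = round(coeff / scale) + zero_point                 # quantize -> 255, top of range
assert q == quant_max
assert math.isclose((q - zero_point) * scale, coeff)  # dequantize recovers coeff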

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 8 deletions

@@ -203,22 +203,22 @@ def example_inputs(self):


 class Conv1dSequential(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, bias=True):
         super().__init__()
         self.first = torch.nn.Conv1d(
             in_channels=1,
             out_channels=3,
             kernel_size=(3),
             padding=1,
-            bias=True,
+            bias=bias,
         )

         self.second = torch.nn.Conv1d(
             in_channels=3,
             out_channels=2,
             kernel_size=(3),
             padding=1,
-            bias=True,
+            bias=bias,
         )

     def forward(self, x):

@@ -315,36 +315,36 @@ def forward(self, x):


 class Conv2dSequential(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, bias=True):
         super().__init__()
         self.first = torch.nn.Conv2d(
             in_channels=1,
             out_channels=3,
             kernel_size=(3, 3),
             padding=1,
-            bias=True,
+            bias=bias,
         )
         self.second = torch.nn.Conv2d(
             in_channels=3,
             out_channels=2,
             kernel_size=(3, 3),
             padding=1,
-            bias=True,
+            bias=bias,
         )

     def forward(self, x):
         return self.second(self.first(x))


 class Conv2dSingle(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, bias=True):
         super().__init__()
         self.conv = torch.nn.Conv2d(
             in_channels=1,
             out_channels=3,
             kernel_size=(3, 3),
             padding=1,
-            bias=True,
+            bias=bias,
         )

     def forward(self, x):
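The new keyword threads through to every layer, which is what lets the tests exercise the bias-free path; e.g., with the module defined above:

import torch

module = Conv2dSequential(bias=False)  # both convs are built without bias
x = torch.randn([1, 1, 3, 3])
out = module(x)                        # shape: [1, 2, 3, 3]
assert module.first.bias is None and module.second.bias is None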

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 25 additions & 15 deletions

@@ -109,14 +109,18 @@ def test_qnn_backend_clamp(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv1d(self):
-        module = Conv1dSequential()  # noqa: F405
+        modules = [Conv1dSequential(), Conv1dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3]),)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv2d(self):
-        module = Conv2dSequential()  # noqa: F405
+        modules = [Conv2dSequential(), Conv2dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_element_wise_add(self):
         test_comb = [

@@ -597,12 +601,14 @@ def setUp(self):
         )

     def test_qnn_backend_16a4w_conv2d(self):
-        module = Conv2dSingle()  # noqa: F405
+        modules = [Conv2dSingle(), Conv2dSingle(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
-        module = self.get_qdq_module(
-            module, sample_input, quant_dtype=QuantDtype.use_16a4w
-        )
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    module, sample_input, quant_dtype=QuantDtype.use_16a4w
+                )
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_16a4w_linear(self):
         module = Linear()  # noqa: F405

@@ -683,16 +689,20 @@ def test_qnn_backend_clamp(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv1d(self):
-        module = Conv1dSequential()  # noqa: F405
+        modules = [Conv1dSequential(), Conv1dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3]),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv2d(self):
-        module = Conv2dSequential()  # noqa: F405
+        modules = [Conv2dSequential(), Conv2dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_element_wise_add(self):
         test_comb = [
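The subTest wrapper keeps the biased and bias-free variants independently reported, so a failure in one case does not mask the other. A minimal standalone illustration of the pattern:

import unittest

class ExampleTest(unittest.TestCase):
    def test_variants(self):
        # each iteration is reported as its own sub-result; a failing
        # variant does not abort the remaining iterations
        for i, bias in enumerate([True, False]):
            with self.subTest(i=i):
                self.assertIsInstance(bias, bool)

if __name__ == "__main__":
    unittest.main()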

backends/qualcomm/utils/constants.py

Lines changed: 7 additions & 0 deletions

@@ -7,16 +7,23 @@
 # Qualcomm specific key

 # constants in backends/qualcomm/passes & backends/qualcomm/builders
+QCOM_AXIS = "axis"
 QCOM_AXIS_ORDER = "axis_order"
 QCOM_BITWIDTH = "bitwidth"
 QCOM_DATA = "data"
+QCOM_DTYPE = "dtype"
 QCOM_ENCODING = "encoding"
 QCOM_INSERTED_PERMUTE = "qnn_permute"
+QCOM_OFFSET = "offset"
 QCOM_QUANTIZED_IO = "q_tensor_io"
 QCOM_QUANT_ATTRS = "quant_attrs"
+QCOM_QUANT_MIN = "quant_min"
+QCOM_QUANT_MAX = "quant_max"
 QCOM_REQUANTIZE = "requantize"
+QCOM_SCALE = "scale"
 QCOM_SCALES = "scales"
 QCOM_SCALE_OFFSET = "scale_offset"
+QCOM_ZERO_POINT = "zero_point"
 QCOM_ZERO_POINTS = "zero_points"

 # constants in backends/qualcomm/tests
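Replacing the scattered string literals with these named constants trades silent key typos for immediate import-time errors; a minimal illustration:

from executorch.backends.qualcomm.utils.constants import QCOM_SCALE, QCOM_ZERO_POINT

quant_attrs = {QCOM_SCALE: 0.05, QCOM_ZERO_POINT: 128}
zp = quant_attrs[QCOM_ZERO_POINT]  # a misspelled constant name raises ImportError
                                   # or NameError, whereas a misspelled literal
                                   # like "zero_pont" would silently miss the key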
