Commit a84a06c

anijain2305 authored and Ubuntu committed
[QNN] Support 4D padding. (apache#5036)
* [QNN] Support 4D padding.

* Empty commit.

Co-authored-by: Ubuntu <ubuntu@ip-172-31-38-96.us-west-2.compute.internal>
1 parent 3987207 commit a84a06c

File tree

3 files changed: +38 -7 lines

python/tvm/relay/qnn/op/qnn.py (4 additions, 0 deletions)

@@ -19,6 +19,7 @@
 
 from __future__ import absolute_import as _abs
 from tvm.relay.expr import Tuple
+from tvm.relay.op.nn.util import get_pad_tuple2d
 from . import _make
 
 def requantize(data,
@@ -280,6 +281,9 @@ def conv2d(data,
         The computed result.
     """
 
+    # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
+    # convert 2-way padding to 4-way padding
+    padding = get_pad_tuple2d(padding)
     return _make.conv2d(data, kernel,
                         input_zero_point, kernel_zero_point,
                         input_scale, kernel_scale,
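
For context, the conversion the new get_pad_tuple2d call performs can be sketched in plain Python as below. This is an illustrative re-implementation, not TVM's helper; it only assumes the (pad_top, pad_left, pad_bottom, pad_right) ordering that the C++ change below relies on, and to_4way_padding is a made-up name.

# Illustrative sketch only (not TVM's implementation): expand a conv2d
# padding argument into the 4-way (pad_top, pad_left, pad_bottom, pad_right)
# form that qnn.conv2d now uses internally.
def to_4way_padding(padding):
    if isinstance(padding, int):
        padding = (padding, padding)
    if len(padding) == 2:        # (pad_h, pad_w): pad each axis symmetrically
        pad_h, pad_w = padding
        return (pad_h, pad_w, pad_h, pad_w)
    if len(padding) == 4:        # already (top, left, bottom, right)
        return tuple(padding)
    raise ValueError("padding must have 2 or 4 values")

assert to_4way_padding((2, 3)) == (2, 3, 2, 3)
assert to_4way_padding((1, 1, 2, 2)) == (1, 1, 2, 2)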

src/relay/qnn/op/convolution.cc (10 additions, 7 deletions)

@@ -177,13 +177,17 @@ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero
 Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) {
   // 1) Pad the input data
   auto padded_data = data;
-  auto pad_h_value = get_const_int(param->padding[0]);
-  auto pad_w_value = get_const_int(param->padding[1]);
-  if (pad_h_value != 0 || pad_w_value != 0) {
+  auto pad_top_value = get_const_int(param->padding[0]);
+  auto pad_left_value = get_const_int(param->padding[1]);
+  auto pad_bottom_value = get_const_int(param->padding[2]);
+  auto pad_right_value = get_const_int(param->padding[3]);
+  bool do_pad = pad_top_value != 0 || pad_left_value != 0 ||
+                pad_bottom_value != 0 || pad_right_value != 0;
+  if (do_pad) {
     Array<IndexExpr> pad_n({0, 0});
     Array<IndexExpr> pad_c({0, 0});
-    Array<IndexExpr> pad_h({param->padding[0], param->padding[0]});
-    Array<IndexExpr> pad_w({param->padding[1], param->padding[1]});
+    Array<IndexExpr> pad_h({param->padding[0], param->padding[2]});
+    Array<IndexExpr> pad_w({param->padding[1], param->padding[3]});
 
     Array<Array<IndexExpr>> pad_width;
     if (param->data_layout == "NCHW") {
@@ -336,7 +340,7 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i
  */
 Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const Conv2DAttrs* param) {
   // Lowering for Term 1
-  Array<IndexExpr> padding({0, 0});
+  Array<IndexExpr> padding({0, 0, 0, 0});
   return Conv2D(padded_data, weight, param->strides, padding, param->dilation, param->groups,
                 param->channels, param->kernel_size, param->data_layout, param->kernel_layout,
                 param->out_layout, param->out_dtype);
@@ -583,7 +587,6 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto* param = attrs.as<Conv2DAttrs>();
   CHECK(param != nullptr);
   // Assertion checks for exisiing support.
-  CHECK_EQ(param->padding.size(), 2) << "qnn.conv2d only supports 2D padding";
   CHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC")
       << "qnn.conv2d supports only NCHW/NHWC input data layout.";
   CHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" ||
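
The padded-input construction above can be pictured in Python as follows. This is only a sketch of the same logic, under the assumption (consistent with the diff) that padding is ordered (top, left, bottom, right) and that only the spatial axes of NCHW/NHWC inputs are padded; qnn_pad_width is a hypothetical name, not a TVM function.

# Sketch (hypothetical helper): the pad_width tuples Conv2DPadInput builds
# for the pad operator, given 4-way padding (top, left, bottom, right).
def qnn_pad_width(padding, data_layout):
    pad_top, pad_left, pad_bottom, pad_right = padding
    pad_n = (0, 0)                    # batch axis: never padded
    pad_c = (0, 0)                    # channel axis: never padded
    pad_h = (pad_top, pad_bottom)     # height axis
    pad_w = (pad_left, pad_right)     # width axis
    if data_layout == "NCHW":
        return (pad_n, pad_c, pad_h, pad_w)
    if data_layout == "NHWC":
        return (pad_n, pad_h, pad_w, pad_c)
    raise ValueError("unsupported data layout: " + data_layout)

# e.g. the asymmetric case exercised by the new test below:
# qnn_pad_width((1, 1, 2, 2), "NHWC") == ((0, 0), (1, 2), (1, 2), (0, 0))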

tests/python/relay/test_op_qnn_conv2d.py (24 additions, 0 deletions)

@@ -496,6 +496,30 @@ def test_padding():
         verify(ref_func, qnn_func, data_shape, data_dtype,
                kernel_shape, kernel_dtype)
 
+        # Try asymmetric padding
+        data_shape = (2, 2, 4, 4) # NHWC
+        data_dtype = 'uint8'
+        kernel_shape = (2, 2, 4, 3) # HWIO
+        kernel_dtype = 'uint8'
+        ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                       data_dtype=data_dtype,
+                                       kernel_shape=kernel_shape,
+                                       kernel_dtype=kernel_dtype,
+                                       input_zero_point=8,
+                                       kernel_zero_point=3,
+                                       input_scale=1.0,
+                                       kernel_scale=1.0,
+                                       kernel_size=(2, 2),
+                                       padding=(1, 1, 2, 2),
+                                       strides=(1, 1),
+                                       dilation=(1, 1),
+                                       data_layout="NHWC",
+                                       kernel_layout="HWIO",
+                                       out_dtype="int32")
+        verify(ref_func, qnn_func, data_shape, data_dtype,
+               kernel_shape, kernel_dtype)
+
+
 def test_dilation():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
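
As a rough sanity check on the shapes in the new asymmetric-padding test, the padded and output spatial sizes can be worked out by hand. The snippet below is illustrative NumPy only; it does not call qnn.conv2d and simply applies the standard convolution size formula to the test's shapes.

import numpy as np

# Shapes from the test above: NHWC data (2, 2, 4, 4), HWIO kernel (2, 2, 4, 3),
# padding (top, left, bottom, right) = (1, 1, 2, 2), stride 1, dilation 1.
n, h, w, c = 2, 2, 4, 4
kh, kw, in_c, out_c = 2, 2, 4, 3
top, left, bottom, right = 1, 1, 2, 2

data = np.zeros((n, h, w, c), dtype="uint8")
padded = np.pad(data, ((0, 0), (top, bottom), (left, right), (0, 0)))
out_h = padded.shape[1] - kh + 1    # stride 1, dilation 1
out_w = padded.shape[2] - kw + 1
print(padded.shape)                 # (2, 5, 7, 4)
print((n, out_h, out_w, out_c))     # (2, 4, 6, 3)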