4 changes: 4 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -19,6 +19,7 @@
 
 from __future__ import absolute_import as _abs
 from tvm.relay.expr import Tuple
+from tvm.relay.op.nn.util import get_pad_tuple2d
 from . import _make
 
 def requantize(data,
@@ -280,6 +281,9 @@ def conv2d(data,
         The computed result.
     """
 
+    # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
+    # convert 2-way padding to 4-way padding
+    padding = get_pad_tuple2d(padding)
     return _make.conv2d(data, kernel,
                         input_zero_point, kernel_zero_point,
                         input_scale, kernel_scale,
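Note: `get_pad_tuple2d` (imported above from `tvm.relay.op.nn.util`) is what normalizes the user-facing padding into the explicit 4-way form the C++ canonicalization now expects. Below is a minimal sketch of that conversion, reimplemented standalone for illustration; the `(top, left, bottom, right)` order is assumed from the indexing in `convolution.cc` further down, and `pad_tuple2d_sketch` is a hypothetical name, not the TVM source:

```python
# Illustration only: how 1-, 2-, and 4-element padding is expected to
# normalize to 4-way (top, left, bottom, right) padding.
def pad_tuple2d_sketch(padding):
    if isinstance(padding, int):
        padding = (padding,)
    if len(padding) == 1:
        return (padding[0],) * 4              # same pad on all four sides
    if len(padding) == 2:
        pad_h, pad_w = padding
        return (pad_h, pad_w, pad_h, pad_w)   # symmetric H/W -> 4-way
    if len(padding) == 4:
        return tuple(padding)                 # already 4-way, pass through
    raise ValueError("padding must have 1, 2, or 4 elements")

assert pad_tuple2d_sketch((1, 2)) == (1, 2, 1, 2)
assert pad_tuple2d_sketch((1, 1, 2, 2)) == (1, 1, 2, 2)
```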
17 changes: 10 additions & 7 deletions src/relay/qnn/op/convolution.cc
@@ -177,13 +177,17 @@ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero
 Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) {
   // 1) Pad the input data
   auto padded_data = data;
-  auto pad_h_value = get_const_int(param->padding[0]);
-  auto pad_w_value = get_const_int(param->padding[1]);
-  if (pad_h_value != 0 || pad_w_value != 0) {
+  auto pad_top_value = get_const_int(param->padding[0]);
+  auto pad_left_value = get_const_int(param->padding[1]);
+  auto pad_bottom_value = get_const_int(param->padding[2]);
+  auto pad_right_value = get_const_int(param->padding[3]);
+  bool do_pad = pad_top_value != 0 || pad_left_value != 0 ||
+                pad_bottom_value != 0 || pad_right_value != 0;
+  if (do_pad) {
     Array<IndexExpr> pad_n({0, 0});
     Array<IndexExpr> pad_c({0, 0});
-    Array<IndexExpr> pad_h({param->padding[0], param->padding[0]});
-    Array<IndexExpr> pad_w({param->padding[1], param->padding[1]});
+    Array<IndexExpr> pad_h({param->padding[0], param->padding[2]});
+    Array<IndexExpr> pad_w({param->padding[1], param->padding[3]});
 
     Array<Array<IndexExpr>> pad_width;
     if (param->data_layout == "NCHW") {
@@ -336,7 +340,7 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i
  */
 Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const Conv2DAttrs* param) {
   // Lowering for Term 1
-  Array<IndexExpr> padding({0, 0});
+  Array<IndexExpr> padding({0, 0, 0, 0});
   return Conv2D(padded_data, weight, param->strides, padding, param->dilation, param->groups,
                 param->channels, param->kernel_size, param->data_layout, param->kernel_layout,
                 param->out_layout, param->out_dtype);
@@ -583,7 +587,6 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const auto* param = attrs.as<Conv2DAttrs>();
   CHECK(param != nullptr);
   // Assertion checks for existing support.
-  CHECK_EQ(param->padding.size(), 2) << "qnn.conv2d only supports 2D padding";
   CHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC")
       << "qnn.conv2d supports only NCHW/NHWC input data layout.";
   CHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" ||
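The padding indices above read as `(top, left, bottom, right)`: `pad_h` pairs indices 0 and 2, `pad_w` pairs indices 1 and 3, and the batch/channel dimensions are never padded. A Python sketch mirroring how `Conv2DPadInput` is expected to assemble the per-dimension pad widths for the two supported layouts (the NHWC ordering is inferred from the layout string, since the hunk is cut off after the NCHW branch):

```python
# Illustrative mirror of the pad_width construction in Conv2DPadInput,
# not the TVM source. padding = (top, left, bottom, right).
def build_pad_width(padding, data_layout):
    top, left, bottom, right = padding
    pad_n, pad_c = (0, 0), (0, 0)            # no padding on batch/channel
    pad_h, pad_w = (top, bottom), (left, right)
    if data_layout == "NCHW":
        return [pad_n, pad_c, pad_h, pad_w]
    if data_layout == "NHWC":
        return [pad_n, pad_h, pad_w, pad_c]  # assumed NHWC branch
    raise ValueError("qnn.conv2d supports only NCHW/NHWC input data layout.")

# Asymmetric padding from the new test below: 1 on top/left, 2 on bottom/right.
assert build_pad_width((1, 1, 2, 2), "NHWC") == [(0, 0), (1, 2), (1, 2), (0, 0)]
```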
24 changes: 24 additions & 0 deletions tests/python/relay/test_op_qnn_conv2d.py
@@ -496,6 +496,30 @@ def test_padding():
         verify(ref_func, qnn_func, data_shape, data_dtype,
                kernel_shape, kernel_dtype)
 
+        # Try asymmetric padding
+        data_shape = (2, 2, 4, 4)  # NHWC
+        data_dtype = 'uint8'
+        kernel_shape = (2, 2, 4, 3)  # HWIO
+        kernel_dtype = 'uint8'
+        ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                       data_dtype=data_dtype,
+                                       kernel_shape=kernel_shape,
+                                       kernel_dtype=kernel_dtype,
+                                       input_zero_point=8,
+                                       kernel_zero_point=3,
+                                       input_scale=1.0,
+                                       kernel_scale=1.0,
+                                       kernel_size=(2, 2),
+                                       padding=(1, 1, 2, 2),
+                                       strides=(1, 1),
+                                       dilation=(1, 1),
+                                       data_layout="NHWC",
+                                       kernel_layout="HWIO",
+                                       out_dtype="int32")
+        verify(ref_func, qnn_func, data_shape, data_dtype,
+               kernel_shape, kernel_dtype)
+
+
 def test_dilation():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
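For intuition on what the asymmetric case exercises: the QNN lowering pads the input with the input zero point (so the padded ring dequantizes to zero) and then runs the convolution with no padding. A NumPy-only sketch of that input-padding step, using the shapes and parameters from the test above:

```python
import numpy as np

# NHWC data and padding=(1, 1, 2, 2) -> (top, left, bottom, right),
# as in the asymmetric-padding test; input_zero_point = 8.
data = np.random.randint(0, 256, size=(2, 2, 4, 4)).astype("uint8")
padded = np.pad(data,
                ((0, 0), (1, 2), (1, 2), (0, 0)),  # (N, H, W, C) pad widths
                mode="constant",
                constant_values=8)                 # pad with the input zero point
assert padded.shape == (2, 5, 7, 4)               # H: 1+2+2, W: 1+4+2
```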