@@ -177,13 +177,17 @@ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero
177177Expr Conv2DPadInput (const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) {
178178 // 1) Pad the input data
179179 auto padded_data = data;
180- auto pad_h_value = get_const_int (param->padding [0 ]);
181- auto pad_w_value = get_const_int (param->padding [1 ]);
182- if (pad_h_value != 0 || pad_w_value != 0 ) {
180+ auto pad_top_value = get_const_int (param->padding [0 ]);
181+ auto pad_left_value = get_const_int (param->padding [1 ]);
182+ auto pad_bottom_value = get_const_int (param->padding [2 ]);
183+ auto pad_right_value = get_const_int (param->padding [3 ]);
184+ bool do_pad = pad_top_value != 0 || pad_left_value != 0 ||
185+ pad_bottom_value != 0 || pad_right_value != 0 ;
186+ if (do_pad) {
183187 Array<IndexExpr> pad_n ({0 , 0 });
184188 Array<IndexExpr> pad_c ({0 , 0 });
185- Array<IndexExpr> pad_h ({param->padding [0 ], param->padding [0 ]});
186- Array<IndexExpr> pad_w ({param->padding [1 ], param->padding [1 ]});
189+ Array<IndexExpr> pad_h ({param->padding [0 ], param->padding [2 ]});
190+ Array<IndexExpr> pad_w ({param->padding [1 ], param->padding [3 ]});
187191
188192 Array<Array<IndexExpr>> pad_width;
189193 if (param->data_layout == " NCHW" ) {
@@ -336,7 +340,7 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i
336340 */
337341Expr Conv2DFirstTerm (const Expr& padded_data, const Expr& weight, const Conv2DAttrs* param) {
338342 // Lowering for Term 1
339- Array<IndexExpr> padding ({0 , 0 });
343+ Array<IndexExpr> padding ({0 , 0 , 0 , 0 });
340344 return Conv2D (padded_data, weight, param->strides , padding, param->dilation , param->groups ,
341345 param->channels , param->kernel_size , param->data_layout , param->kernel_layout ,
342346 param->out_layout , param->out_dtype );
@@ -583,7 +587,6 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
583587 const auto * param = attrs.as <Conv2DAttrs>();
584588 CHECK (param != nullptr );
585589 // Assertion checks for exisiing support.
586- CHECK_EQ (param->padding .size (), 2 ) << " qnn.conv2d only supports 2D padding" ;
587590 CHECK (param->data_layout == " NCHW" || param->data_layout == " NHWC" )
588591 << " qnn.conv2d supports only NCHW/NHWC input data layout." ;
589592 CHECK (param->kernel_layout == " OIHW" || param->kernel_layout == " HWIO" ||
0 commit comments