@@ -492,7 +492,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
492492 return tvm::max (x (indices), { dheight, dwidth }); // NOLINT(*)
493493 }, " tensor" , " adaptive_pool_max" );
494494 } else if (pool_type == kAvgPool ) {
495- return tvm::compute (out_shape, [&](const Array<Var>& output) {
495+ auto pool_sum = tvm::compute (out_shape, [&](const Array<Var>& output) {
496496 Array<Expr> indices;
497497 for (const Var& var : output) indices.push_back (var);
498498 auto i_start_h = start_index (output[height_axis], out_height, height);
@@ -505,8 +505,20 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
505505 auto dwidth = tvm::reduce_axis (Range (0 , i_end_w - i_start_w), " rv2" );
506506 indices.Set (height_axis, i_start_h + dheight);
507507 indices.Set (width_axis, i_start_w + dwidth);
508- return tvm::sum (div (x (indices), divide_factor), { dheight, dwidth });
509- }, " tensor" , " adaptive_pool_avg" );
508+ return tvm::sum (x (indices), { dheight, dwidth });
509+ }, " tensor" , " adaptive_pool_sum" );
510+
511+ return tvm::compute (out_shape, [&](const Array<Var>& output) {
512+ Array<Expr> indices;
513+ for (const Var& var : output) indices.push_back (var);
514+ auto i_start_h = start_index (output[height_axis], out_height, height);
515+ auto i_end_h = end_index (output[height_axis], out_height, height);
516+ auto i_start_w = start_index (output[width_axis], out_width, width);
517+ auto i_end_w = end_index (output[width_axis], out_width, width);
518+ auto divide_factor = tvm::cast (x->dtype , (i_end_h - i_start_h)
519+ * (i_end_w - i_start_w));
520+ return div (pool_sum (indices), divide_factor);
521+ }, " tensor" , kElementWise );
510522 } else {
511523 LOG (ERROR) << " Unrecognized pool_type: " << pool_type;
512524 return x;
0 commit comments