Skip to content

Commit c9aa55c

Browse files
kevinthesun authored and tqchen committed
Split adaptive_pool2d_avg into sum and div (#4186)
1 parent 2460f90 commit c9aa55c

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

topi/include/topi/nn/pooling.h

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -492,7 +492,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
492492
return tvm::max(x(indices), { dheight, dwidth }); // NOLINT(*)
493493
}, "tensor", "adaptive_pool_max");
494494
} else if (pool_type == kAvgPool) {
495-
return tvm::compute(out_shape, [&](const Array<Var>& output) {
495+
auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
496496
Array<Expr> indices;
497497
for (const Var& var : output) indices.push_back(var);
498498
auto i_start_h = start_index(output[height_axis], out_height, height);
@@ -505,8 +505,20 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
505505
auto dwidth = tvm::reduce_axis(Range(0, i_end_w - i_start_w), "rv2");
506506
indices.Set(height_axis, i_start_h + dheight);
507507
indices.Set(width_axis, i_start_w + dwidth);
508-
return tvm::sum(div(x(indices), divide_factor), { dheight, dwidth });
509-
}, "tensor", "adaptive_pool_avg");
508+
return tvm::sum(x(indices), { dheight, dwidth });
509+
}, "tensor", "adaptive_pool_sum");
510+
511+
return tvm::compute(out_shape, [&](const Array<Var>& output) {
512+
Array<Expr> indices;
513+
for (const Var& var : output) indices.push_back(var);
514+
auto i_start_h = start_index(output[height_axis], out_height, height);
515+
auto i_end_h = end_index(output[height_axis], out_height, height);
516+
auto i_start_w = start_index(output[width_axis], out_width, width);
517+
auto i_end_w = end_index(output[width_axis], out_width, width);
518+
auto divide_factor = tvm::cast(x->dtype, (i_end_h - i_start_h)
519+
* (i_end_w - i_start_w));
520+
return div(pool_sum(indices), divide_factor);
521+
}, "tensor", kElementWise);
510522
} else {
511523
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
512524
return x;

topi/python/topi/x86/pooling.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,11 @@ def traverse(OP):
147147
traverse(tensor.op)
148148
# schedule pool
149149
elif OP.tag.startswith('adaptive_pool'):
150+
if OP != outs[0].op:
151+
output = outs[0]
152+
output_fused = s[output].fuse(output.op.axis[0], output.op.axis[1])
153+
s[output].parallel(output_fused)
154+
150155
Pool = OP.output(0)
151156
_parallel_sch(s[Pool], outs[0].shape)
152157
else:

0 commit comments

Comments
 (0)