@@ -23,23 +23,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/pooling.h"
 #if defined(__HIPCC__) || defined(__NVCC__)
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #endif
 
 namespace paddle {
 namespace operators {
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-
-  template <typename U>
-  HOSTDEVICE inline U operator()(const U& x) const {
-    return x * static_cast<U>(n_inv);
-  }
-
- private:
-  T n_inv;
-};
 
 using Tensor = framework::Tensor;
 
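Note on the hunk above: the hand-rolled DivideFunctor implemented the mean as a cub::Sum reduction followed by a multiply-by-1/n transform; that responsibility now lives in the shared CustomMean functor pulled in from reduce_functor_op.h. A minimal standalone sketch of the sum-then-scale pattern the removed code implemented (plain C++, names hypothetical, not Paddle's actual implementation):

    #include <cstddef>

    // Post-reduction transform: multiplies the summed value by 1/n,
    // which is what the removed DivideFunctor did on device.
    template <typename T>
    struct DivideTransformSketch {
      explicit DivideTransformSketch(int n) : n_inv(static_cast<T>(1.0 / n)) {}
      T operator()(const T& x) const { return x * n_inv; }
     private:
      T n_inv;
    };

    // Sum reduction followed by the 1/n transform; the GPU kernel performs
    // the same two steps, just in parallel over the reduce dimensions.
    template <typename T>
    T MeanSketch(const T* data, std::size_t n) {
      T acc = static_cast<T>(0);
      for (std::size_t i = 0; i < n; ++i) acc += data[i];
      return DivideTransformSketch<T>(static_cast<int>(n))(acc);
    }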
@@ -219,9 +208,7 @@ class PoolKernel : public framework::OpKernel<T> {
               adaptive) {  // for adaptive_avg_pool2d && output_size == 1
 #if defined(__HIPCC__) || defined(__NVCC__)
         auto stream = dev_ctx.stream();
-        TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
-            *in_x, out, reduce_dim, static_cast<T>(0), cub::Sum(),
-            DivideFunctor<T>(reduce_num), stream);
+        TensorReduceFunc<T, T, CustomMean>(*in_x, out, reduce_dim, stream);
 #else  // for cpu
         paddle::operators::math::Pool2dFunctor<
             DeviceContext, paddle::operators::math::AvgPool<T>, T>
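For context on this branch: per the comment in the hunk, adaptive_avg_pool2d with output_size == 1 degenerates to a per-channel mean over the spatial dims, which is why a single tensor reduction suffices on the GPU path, while the CPU path falls back to the generic Pool2dFunctor with AvgPool<T>. A hedged sketch of that degenerate case (hypothetical helper, not Paddle code):

    #include <cstddef>
    #include <vector>

    // adaptive_avg_pool2d with output_size == 1 on an NCHW tensor is just
    // the mean over H*W for each (n, c) slice.
    std::vector<float> GlobalAvgPoolNCHW(const std::vector<float>& x,
                                         int n, int c, int h, int w) {
      const int hw = h * w;
      std::vector<float> out(static_cast<std::size_t>(n) * c, 0.0f);
      for (int i = 0; i < n * c; ++i) {
        float sum = 0.0f;
        for (int j = 0; j < hw; ++j) {
          sum += x[static_cast<std::size_t>(i) * hw + j];
        }
        out[i] = sum / static_cast<float>(hw);  // mean over spatial dims
      }
      return out;
    }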