diff --git a/paddle/phi/kernels/gpu/dirichlet_kernel.cu b/paddle/phi/kernels/gpu/dirichlet_kernel.cu index 09d6a402e701a..912c84bf26c21 100644 --- a/paddle/phi/kernels/gpu/dirichlet_kernel.cu +++ b/paddle/phi/kernels/gpu/dirichlet_kernel.cu @@ -16,12 +16,14 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/funcs/reduce_functor.h" #include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" #ifdef PADDLE_WITH_CUDA #include @@ -99,15 +101,14 @@ struct DirichletSampler { gamma_sum.Resize(new_shape); dev_ctx.template Alloc(&gamma_sum); - funcs::ReduceKernelImpl( - dev_ctx, - gamma_samples, - &gamma_sum, - {new_shape.size() - 1}, - true, - false); - funcs::ElementwiseCompute, T>( - dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor(), out); + phi::SumRawKernel(dev_ctx, + gamma_samples, + {new_shape.size() - 1}, + true, + false, + gamma_sum.dtype(), + &gamma_sum); + phi::DivideKernel(dev_ctx, gamma_samples, gamma_sum, out); } }; } // namespace phi