Skip to content

Commit

Permalink
compilation optimization for dirichlet_kernel (#57815)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianhaodongbd authored Sep 28, 2023
1 parent e497279 commit a46df40
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions paddle/phi/kernels/gpu/dirichlet_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

#ifdef PADDLE_WITH_CUDA
#include <curand_kernel.h>
Expand Down Expand Up @@ -99,15 +101,14 @@ struct DirichletSampler<GPUContext, T> {
gamma_sum.Resize(new_shape);
dev_ctx.template Alloc<T>(&gamma_sum);

funcs::ReduceKernelImpl<GPUContext, T, T, funcs::SumFunctor>(
dev_ctx,
gamma_samples,
&gamma_sum,
{new_shape.size() - 1},
true,
false);
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor<T>(), out);
phi::SumRawKernel<T, GPUContext>(dev_ctx,
gamma_samples,
{new_shape.size() - 1},
true,
false,
gamma_sum.dtype(),
&gamma_sum);
phi::DivideKernel<T, GPUContext>(dev_ctx, gamma_samples, gamma_sum, out);
}
};
} // namespace phi
Expand Down

0 comments on commit a46df40

Please sign in to comment.