Skip to content

Commit

Permalink
register bf16 dtype for top_p_sampling kernel (PaddlePaddle#67769)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored Aug 28, 2024
1 parent dddda15 commit 4eb9813
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ namespace cub = hipcub;
#include <cub/cub.cuh>
#endif

// Feature guard for the bfloat16 code paths in this file: the macro is defined
// (and the CUDA bfloat16 header is pulled in) only when this translation unit
// is compiled by nvcc (__CUDACC__) and CUDA_VERSION >= 11060, i.e. CUDA 11.6
// or newer (CUDA_VERSION is encoded as major * 1000 + minor * 10).
// NOTE: "AVALIABLE" is a misspelling of "AVAILABLE". The spelling is kept on
// purpose because the #ifdef checks further down in this file test this exact
// name; renaming it requires updating every one of those checks together.
#if defined(__CUDACC__) && CUDA_VERSION >= 11060
#define CUDA_BFLOAT16_AVALIABLE
#include <cuda_bf16.h>
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/kernel_registry.h"
Expand Down Expand Up @@ -55,7 +60,7 @@ struct DataTypeTraits<phi::dtype::float16> {
using DataType = half;
};

#ifdef PADDLE_CUDA_BF16
#ifdef CUDA_BFLOAT16_AVALIABLE
template <>
struct DataTypeTraits<phi::dtype::bfloat16> {
using DataType = __nv_bfloat16;
Expand Down Expand Up @@ -1241,6 +1246,18 @@ void TopPSamplingKernel(const Context& dev_ctx,

} // namespace phi

#ifdef CUDA_BFLOAT16_AVALIABLE
// Registration of the GPU top_p_sampling kernel used when
// CUDA_BFLOAT16_AVALIABLE is defined. The dtype list ends with
// phi::dtype::bfloat16, which the fallback registration in the #else branch
// omits. The trailing empty braces supply no additional registration body.
PD_REGISTER_KERNEL(top_p_sampling,
GPU,
ALL_LAYOUT,
phi::TopPSamplingKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(top_p_sampling,
GPU,
ALL_LAYOUT,
Expand All @@ -1250,3 +1267,4 @@ PD_REGISTER_KERNEL(top_p_sampling,
int,
int64_t,
phi::dtype::float16) {}
#endif

0 comments on commit 4eb9813

Please sign in to comment.