[DCU] fix fused_bias_act op #65399

Merged (4 commits, Jun 26, 2024)

paddle/phi/kernels/funcs/load_store_util.h (1 addition, 2 deletions)

@@ -20,7 +20,6 @@
 namespace phi {
 namespace funcs {
 
-#ifndef PADDLE_WITH_HIP
 template <typename T>
 __device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
   if (v > max) return max;
@@ -216,6 +215,6 @@ struct QuantStore<T, true> {
   const T *smooth_;
   const int cols_;
 };
-#endif
+
 }  // namespace funcs
 }  // namespace phi

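With the PADDLE_WITH_HIP guard removed, ClipFunc and the QuantLoad/QuantStore helpers in this header now compile for ROCm builds as well. For context, a host-side sketch of the clamp these helpers apply when quantizing; the `v < min` branch is truncated in the hunk above and assumed here, and the `__device__` qualifier is dropped:

```cpp
#include <cstdint>
#include <cstdio>

// Host-side mirror of ClipFunc from load_store_util.h. The (v < min)
// branch is cut off by the diff above and assumed here.
template <typename T>
T ClipFunc(const T v, const T min, const T max) {
  if (v > max) return max;
  if (v < min) return min;
  return v;
}

int main() {
  // An int8 quantization path clamps the scaled value into [-127, 127]:
  float scaled = 173.4f;
  int8_t q = static_cast<int8_t>(ClipFunc(scaled, -127.0f, 127.0f));
  printf("%d\n", q);  // prints 127
  return 0;
}
```
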
paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu (0 additions, 4 deletions)

@@ -21,7 +21,6 @@ COMMON_DECLARE_bool(use_fast_math);
 namespace phi {
 namespace fusion {
 
-#ifndef PADDLE_WITH_HIP
 template <typename T,
           typename Functor,
           int VecSize,
@@ -432,7 +431,6 @@ void DispatchWithDtype(const Context &dev_ctx,
                        float quant_min_bound,
                        DenseTensor *out,
                        UnusedVersion) {}
-#endif
 
 template <typename T, typename Context>
 void FusedBiasActKernel(const Context &dev_ctx,
@@ -448,7 +446,6 @@ void FusedBiasActKernel(const Context &dev_ctx,
                         float quant_max_bound,
                         float quant_min_bound,
                         DenseTensor *out) {
-#ifndef PADDLE_WITH_HIP
   int rows = x.dims()[0];
   int cols = x.dims()[1];
   if (x.dtype() == phi::DataType::INT32) {
@@ -529,7 +526,6 @@ void FusedBiasActKernel(const Context &dev_ctx,
         out,
         typename DispatchDtypeTrait<T>::FuncVersion{});
   }
-#endif
 }
 
 }  // namespace fusion

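The `typename DispatchDtypeTrait<T>::FuncVersion{}` argument in the last hunk is tag dispatch: the trait maps each dtype to a tag type, and the tag selects either the real `DispatchWithDtype` overload or the empty `UnusedVersion` one, so unsupported dtypes never instantiate the device code. A standalone sketch of the pattern; the tag and trait names follow the diff, while the `NormalVersion` tag and the int32 specialization are assumptions for illustration:

```cpp
#include <cstdint>
#include <cstdio>

// Tag types: NormalVersion picks the real implementation, UnusedVersion
// picks the empty overload seen in the hunk above.
struct NormalVersion {};
struct UnusedVersion {};

// Supported dtypes map to NormalVersion; int32 maps to the no-op
// (assumed specialization for illustration).
template <typename T>
struct DispatchDtypeTrait {
  using FuncVersion = NormalVersion;
};
template <>
struct DispatchDtypeTrait<int32_t> {
  using FuncVersion = UnusedVersion;
};

template <typename T>
void DispatchWithDtype(const T* /*x*/, int n, NormalVersion) {
  printf("dispatching kernel for %d elements\n", n);
}

template <typename T>
void DispatchWithDtype(const T*, int, UnusedVersion) {}  // compiled-out path

int main() {
  float xf[4] = {};
  int32_t xi[4] = {};
  DispatchWithDtype(xf, 4, DispatchDtypeTrait<float>::FuncVersion{});
  DispatchWithDtype(xi, 4, DispatchDtypeTrait<int32_t>::FuncVersion{});  // no-op
  return 0;
}
```
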
paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h (2 additions, 10 deletions)

@@ -23,10 +23,8 @@
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#ifndef PADDLE_WITH_HIP
 #include "paddle/phi/kernels/funcs/load_store_util.h"
 #include "paddle/phi/kernels/gpu/gelu_funcs.h"
-#endif
 // for windows build
 #define M_SQRT1_2 0.70710678118654752440
 
@@ -36,15 +34,10 @@ namespace fusion {
 template <typename T>
 struct FastGeluFunctor {
   inline __device__ T operator()(const T x) const {
-#ifdef PADDLE_WITH_HIP
-    assert(0 && "ROCM does not support FastGelu");
-#else
     return phi::GeluFwd<T, true>(x);
-#endif
   }
 };
 
-#ifndef PADDLE_WITH_HIP
 template <typename T>
 struct GeluComputeType;
 
@@ -119,7 +112,7 @@ struct ReluFunctor {
   }
 };
 
-inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
+inline gpuError_t GetNumBlocks(int64_t n, int *num_blocks) {
   constexpr int kBlockSize = 128;
   constexpr int kNumWaves = 16;
 
@@ -133,9 +126,8 @@ inline gpuError_t GetNumBlocks(int64_t n, int *num_blocks) {
       std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
                         sm_count * max_thread_per_multiprocessor /
                             kBlockSize * kNumWaves));
-  return cudaSuccess;
+  return gpuSuccess;
 }
-#endif
 
 }  // namespace fusion
 }  // namespace phi

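Switching `cudaError_t`/`cudaSuccess` to `gpuError_t`/`gpuSuccess` lets `GetNumBlocks` build under both toolchains via Paddle's HIP/CUDA portability aliases. A minimal sketch of what such an alias layer looks like, illustrative only; Paddle defines its own versions in its GPU declaration headers:

```cpp
// Illustrative HIP/CUDA portability aliases (a sketch of the pattern,
// not Paddle's actual definitions).
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuError_t = hipError_t;
#define gpuSuccess hipSuccess
#else
#include <cuda_runtime.h>
using gpuError_t = cudaError_t;
#define gpuSuccess cudaSuccess
#endif
```

The sizing logic itself is untouched: roughly ceil(n / 128) blocks, capped at sm_count * max_thread_per_multiprocessor / 128 * 16, i.e. about 16 waves of resident blocks per device.
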
paddle/phi/kernels/gpu/gelu_funcs.h (20 additions, 1 deletion)

@@ -26,7 +26,7 @@ COMMON_DECLARE_bool(use_fast_math);
 
 namespace phi {
 
-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
 template <bool FastMode>
 static __device__ __forceinline__ float FP32FastTanh(float x) {
 #if __CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000
@@ -47,6 +47,16 @@ static __device__ __forceinline__ T GeluFwd(T x) {
   return static_cast<T>(cast_x * 0.5f * (1.0f + tanh_out));
 }
 
+#ifdef PADDLE_WITH_HIP
+template <bool FastMode>
+static __device__ __forceinline__ __half GeluFwdHalf(__half x) {
+  const float cast_x = __half2float(x);
+  auto tanh_out = FP32FastTanh<FastMode>(0.79788456f * cast_x *
+                                         (1.0f + 0.044715f * cast_x * cast_x));
+  return __float2half(cast_x * 0.5f * (1.0f + tanh_out));
+}
+#endif
+
 template <bool FastMode>
 static __device__ __forceinline__ float FP32GeluBwd(float x, float y_g) {
   auto tanh_out =
@@ -70,7 +80,11 @@ static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x,
     ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
 #pragma unroll
     for (int i = 0; i < VecSize; ++i) {
+#ifdef PADDLE_WITH_HIP
+      in_arr[i] = GeluFwdHalf<FastMode>(in_arr[i]);
+#else
       in_arr[i] = GeluFwd<half, FastMode>(in_arr[i]);
+#endif
     }
     *reinterpret_cast<ArrT*>(y + offset) = in_arr;
   }
@@ -91,8 +105,13 @@ static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
 #pragma unroll
     for (int i = 0; i < VecSize; ++i) {
       __half2 tmp_fp16_2;
+#ifdef PADDLE_WITH_HIP
+      tmp_fp16_2.x = *reinterpret_cast<uint16_t*>(&x_in_arr[i]);
+      tmp_fp16_2.y = *reinterpret_cast<uint16_t*>(&y_g_in_arr[i]);
+#else
       tmp_fp16_2.x = x_in_arr[i];
       tmp_fp16_2.y = y_g_in_arr[i];
+#endif
       float2 tmp_fp32_2 = __half22float2(tmp_fp16_2);
       x_in_arr[i] =
           __float2half(FP32GeluBwd<FastMode>(tmp_fp32_2.x, tmp_fp32_2.y));

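Both `GeluFwd` and the new HIP-only `GeluFwdHalf` evaluate the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), where 0.79788456f is sqrt(2/pi). A host-side reference sketch in plain C++, with `tanhf` standing in for `FP32FastTanh`:

```cpp
#include <cmath>
#include <cstdio>

// Host-side reference for the tanh-approximated GELU computed by GeluFwd /
// GeluFwdHalf: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// 0.79788456f in the kernels is sqrt(2/pi).
static float GeluRef(float x) {
  float tanh_out = tanhf(0.79788456f * x * (1.0f + 0.044715f * x * x));
  return x * 0.5f * (1.0f + tanh_out);
}

int main() {
  const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
  for (float x : xs) {
    printf("gelu(% .1f) = % .6f\n", x, GeluRef(x));
  }
  return 0;
}
```

The dedicated `__half` overload converts through `__half2float`/`__float2half` explicitly, presumably because ROCm's `__half` lacks the implicit float conversions the templated `GeluFwd<T>` path relies on.
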
test/legacy_test/test_fused_bias_act_op.py (6 additions, 3 deletions)

@@ -94,7 +94,8 @@ def fused_act_bias_wrapper(
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    not core.is_compiled_with_cuda() and not core.is_compiled_with_rocm(),
+    "core is not compiled with CUDA or ROCm",
 )
 class TestFusedBiasActOp(unittest.TestCase):
     def setUp(self):

Review comment (Contributor), on the changed skipIf condition:

    This change is not needed. If you look at the implementation of
    is_compiled_with_cuda, it was originally written to be compatible with
    both HIP and CUDA: when the build is a HIP build, it also returns True.
    See:

        #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)

    The same applies to the lines below.
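The guard the reviewer quotes makes the point concrete. A sketch of the predicate's shape; the function name is illustrative, while the real implementation sits in Paddle's pybind layer behind `core.is_compiled_with_cuda()`:

```cpp
#include <cstdio>

// Sketch of the guard the reviewer quotes: in a HIP build, PADDLE_WITH_HIP
// is defined, the #if condition is false, and the function reports true.
// The function name here is illustrative, not Paddle's actual symbol.
static bool IsCompiledWithCUDA() {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
  return false;
#else
  return true;
#endif
}

int main() {
  printf("compiled with CUDA (or HIP): %s\n",
         IsCompiledWithCUDA() ? "true" : "false");
  return 0;
}
```

So on a ROCm build `core.is_compiled_with_cuda()` already returns True, making the added `is_compiled_with_rocm()` check redundant; per the reviewer, the same holds for the other two skipIf decorators below.
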
@@ -640,7 +641,8 @@ def compute_baseline_output(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    not core.is_compiled_with_cuda() and not core.is_compiled_with_rocm(),
+    "core is not compiled with CUDA or ROCm",
 )
 class TestAssert(unittest.TestCase):
     def setUp(self):
@@ -703,7 +705,8 @@ def test_assert_case3(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    not core.is_compiled_with_cuda() and not core.is_compiled_with_rocm(),
+    "core is not compiled with CUDA or ROCm",
 )
 class TestWithoutBias(unittest.TestCase):
     def setUp(self):