Commit fcfaa10

fc support fp16 (#44540)
1 parent 3b0aa75


paddle/phi/kernels/funcs/fc_functor.cu

Lines changed: 18 additions & 43 deletions
@@ -36,6 +36,24 @@ struct FcTypeTraits<double> {
   typedef double4 Type;
 };
 
+#if defined(PADDLE_WITH_CUDA)
+#include <cuda_fp16.h>
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef half2 Type;
+};
+#else
+struct float16_4 {
+  float16 x, y, z, w;
+};
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef float16_4 Type;
+};
+#endif
+
 template <typename T, bool DoRelu>
 __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
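A note on the added block: FcTypeTraits maps a scalar element type to a packed vector type at compile time, so one templated kernel can move two (half2) or four (float16_4, double4) values per memory transaction. Below is a minimal standalone sketch of the pattern, assuming the PADDLE_WITH_CUDA branch; the demo_copy kernel and the use of plain half in place of paddle's float16 are illustrative, not from this commit.

#include <cuda_fp16.h>

template <typename T>
struct FcTypeTraits;

template <>
struct FcTypeTraits<half> {  // half stands in for paddle's float16 here
  typedef half2 Type;        // two fp16 values per load/store
};

// Hypothetical kernel: reinterpret the fp16 buffers as the packed type
// chosen by the trait, halving the number of memory transactions.
template <typename T>
__global__ void demo_copy(const T* in, T* out, int n_packed) {
  using Vec = typename FcTypeTraits<T>::Type;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_packed) {
    reinterpret_cast<Vec*>(out)[i] = reinterpret_cast<const Vec*>(in)[i];
  }
}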
@@ -109,14 +127,6 @@ void AddReluKernel(
 }
 
 #if defined(PADDLE_WITH_CUDA)
-
-#include <cuda_fp16.h>
-
-template <>
-struct FcTypeTraits<float16> {
-  typedef half2 Type;
-};
-
 template <bool DoRelu>
 __global__ void bias_relu_v2(const int num,
                              const half2* bias,
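The context lines above show bias_relu_v2 taking its bias and data as half2, i.e. two fp16 values per load. The kernel body is not part of this hunk; what follows is a hedged sketch of such a vectorized bias-plus-ReLU body, built only from documented cuda_fp16.h intrinsics (__hadd2, __half22float2, __floats2half2_rn), not Paddle's actual implementation.

#include <cuda_fp16.h>

// Sketch only: adds the (packed) bias to each half2 element, optionally
// applying ReLU. K is the row width in half2 units.
template <bool DoRelu>
__global__ void bias_relu_v2_sketch(const int num,
                                    const half2* bias,
                                    half2* data,
                                    int K) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    half2 v = __hadd2(data[tid], bias[tid % K]);  // two fp16 adds at once
    if (DoRelu) {
      float2 f = __half22float2(v);  // widen so fmaxf works on any arch
      v = __floats2half2_rn(fmaxf(f.x, 0.f), fmaxf(f.y, 0.f));
    }
    data[tid] = v;
  }
}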
@@ -200,46 +210,11 @@ void AddReluKernel(cudaStream_t stream,
 }
 
 #else
-
-struct float16_4 {
-  float16 x, y, z, w;
-};
-template <>
-struct FcTypeTraits<float16> {
-  typedef float16_4 Type;
-};
-
-template <bool DoRelu>
-__global__ void bias_relu_v4(const int num,
-                             const float16_4* bias,
-                             float16_4* data,
-                             int K) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid < num) {
-    int bias_idx = tid % K;
-    const float16_4 bias_ptr = bias[bias_idx];
-    const float16_4 in_ptr = data[tid];
-    float16_4 packed_val;
-    packed_val.x = in_ptr.x + bias_ptr.x;
-    packed_val.y = in_ptr.y + bias_ptr.y;
-    packed_val.z = in_ptr.z + bias_ptr.z;
-    packed_val.w = in_ptr.w + bias_ptr.w;
-    if (DoRelu) {
-      packed_val.x = fmaxf(0.f, packed_val.x);
-      packed_val.y = fmaxf(0.f, packed_val.y);
-      packed_val.z = fmaxf(0.f, packed_val.z);
-      packed_val.w = fmaxf(0.f, packed_val.w);
-    }
-    data[tid] = packed_val;
-  }
-}
-
 template <bool DoRelu, int BlockDim>
 __global__ void InplaceAddReluKernel(const int N,
                                      const float16* bias,
                                      float16* data) {
   int offset = blockIdx.x * N;
-
   for (int i = threadIdx.x; i < N; i += BlockDim) {
     float16 temp;
     temp = data[offset + i] + bias[i];
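Net effect of the deletion above: with the FcTypeTraits<float16> specialization now declared before the generic bias_relu_v4 template (first hunk), the duplicated float16_4 kernel in the #else branch is redundant and goes away; its body relied only on float16's member-wise operator+ and fmaxf, which the generic template path covers. For orientation, here is a hypothetical host-side dispatch routing fp16 through the packed path; the function name, block size, and even-K check are assumptions, not part of this commit.

// Hypothetical dispatcher (assumes bias_relu_v2_sketch from the previous
// sketch). When the row width K is even, treat the buffers as half2 and
// launch the vectorized kernel; otherwise a scalar fallback such as
// InplaceAddReluKernel would be used.
void add_relu_fp16_sketch(cudaStream_t stream, const half* bias,
                          half* data, int M, int K, bool relu) {
  if (K % 2 != 0) return;  // scalar fallback not sketched here
  const int num = M * K / 2;  // total half2 elements in the M x K output
  const int threads = 256;    // assumed block size
  const int blocks = (num + threads - 1) / threads;
  const half2* bias2 = reinterpret_cast<const half2*>(bias);
  half2* data2 = reinterpret_cast<half2*>(data);
  if (relu) {
    bias_relu_v2_sketch<true><<<blocks, threads, 0, stream>>>(
        num, bias2, data2, K / 2);
  } else {
    bias_relu_v2_sketch<false><<<blocks, threads, 0, stream>>>(
        num, bias2, data2, K / 2);
  }
}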
