Fc fp16 #44505
Merged
7 changes: 5 additions & 2 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -136,7 +136,7 @@ const std::vector<std::string> kTRTSubgraphPasses({

const std::vector<std::string> kDlnneSubgraphPasses({
"is_test_pass", //
"delete_dropout_op_pass" //
"delete_dropout_op_pass", //
"simplify_with_basic_ops_pass", //
"conv_bn_fuse_pass", //
"depthwise_conv_bn_fuse_pass", //
@@ -158,7 +158,10 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"conv_eltwiseadd_bn_fuse_pass",
"conv_elementwise_add_act_fuse_pass",
"conv_elementwise_add2_act_fuse_pass",
"conv_elementwise_add_fuse_pass"};
"conv_elementwise_add_fuse_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"fc_fuse_pass"};

const std::vector<std::string> kTrtLowerPrecisionPasses{
// "conv_bn_fuse_pass",
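Note on the pass list above: kGpuLowerPrecisionPasses now maps matmul_v2 to mul/matmul and then runs fc_fuse_pass, so matmul + elementwise_add (+ relu) patterns are rewritten into fc ops that can reach the float16 FC kernel registered below. As a rough, hedged illustration only (this assumes the public paddle_infer::Config API with pass_builder()->AppendPass(); none of it is part of this diff), the same lowering could be appended by hand:

#include "paddle_inference_api.h"

// Sketch, not part of the PR: append the lowering passes manually, assuming
// paddle_infer::Config exposes pass_builder()->AppendPass() as in the public
// inference examples. Pass names are taken from the diff above.
void ConfigureFcFp16(paddle_infer::Config* config) {
  config->EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  auto* builder = config->pass_builder();
  builder->AppendPass("gpu_cpu_map_matmul_v2_to_mul_pass");
  builder->AppendPass("gpu_cpu_map_matmul_v2_to_matmul_pass");
  builder->AppendPass("fc_fuse_pass");  // mul + elementwise_add (+ relu) -> fc
}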
1 change: 1 addition & 0 deletions paddle/fluid/operators/fc_op.cu.cc
@@ -17,5 +17,6 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fc,
ops::FCOpKernel<paddle::platform::CUDADeviceContext, phi::dtype::float16>,
ops::FCOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::FCOpKernel<paddle::platform::CUDADeviceContext, double>);
240 changes: 215 additions & 25 deletions paddle/phi/kernels/funcs/fc_functor.cu
@@ -21,6 +21,8 @@ limitations under the License. */
namespace phi {
namespace funcs {

using float16 = phi::dtype::float16;

template <typename T>
struct FcTypeTraits;

@@ -75,6 +77,216 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
}
}

template <typename T>
void AddReluKernel(
gpuStream_t stream, const int M, const int N, T* Y, const T* B, bool relu) {
if (N % 4 == 0) {
const int threads = 256;
const int num = M * N / 4;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<T>::Type trans_type;
auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v4<trans_type, true><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
} else {
bias_relu_v4<trans_type, false><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
}
} else {
const int threads = 256;
const int blocks = M;

if (relu) {
InplaceAddReluKernel<T, true, threads>
<<<blocks, threads, 0, stream>>>(N, B, Y);
} else {
InplaceAddReluKernel<T, false, threads>
<<<blocks, threads, 0, stream>>>(N, B, Y);
}
}
}
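The generic AddReluKernel above chooses between two launch shapes: when N is a multiple of 4 it reinterprets each row as N / 4 packed elements (float4 or double4 via FcTypeTraits, defined earlier in this file) and assigns one thread per packed element over the whole M x N output; otherwise it launches one 256-thread block per row. A standalone sketch of that launch arithmetic with example sizes M = 8, N = 1024 (illustration only, not Paddle code):

#include <cstdio>

int main() {
  const int M = 8, N = 1024;                         // output is M x N
  const int threads = 256;
  const int num = M * N / 4;                         // packed elements: 2048
  const int blocks = (num + threads - 1) / threads;  // ceil(2048 / 256) = 8
  const int K = N / 4;                               // packed elements per row: 256
  // One thread per packed element; the bias index is tid % K, so the same
  // packed bias value is reused by every row of the output.
  std::printf("num=%d blocks=%d K=%d\n", num, blocks, K);
  return 0;
}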

#if defined(PADDLE_WITH_CUDA)

#include <cuda_fp16.h>

template <>
struct FcTypeTraits<float16> {
typedef half2 Type;
};

template <bool DoRelu>
__global__ void bias_relu_v2(const int num,
const half2* bias,
half2* data,
int K) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const half2 bias_ptr = bias[bias_idx];
const half2 in_ptr = data[tid];
half2 packed_val = __hadd2(bias_ptr, in_ptr);
if (DoRelu) {
#if __CUDA_ARCH__ >= 800
packed_val = __hmax2(__half2(0, 0), packed_val);
#else
packed_val = __hmul2(__hgt2(packed_val, __half2(0, 0)), packed_val);
#endif
}
data[tid] = packed_val;
}
}
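bias_relu_v2 relies on packed half2 arithmetic: __hadd2 adds both fp16 lanes in one instruction, and ReLU is either a packed max against zero (__hmax2, available on compute capability 8.0 and newer) or the mask trick (x > 0) * x built from __hgt2 and __hmul2 on older architectures. A scalar float sketch of the same mask trick, just to make the arithmetic explicit (illustration only):

// relu(x) = (x > 0) * x, e.g. x = -2.5 gives 0 * -2.5 = 0 and x = 1.5 gives 1 * 1.5 = 1.5.
__device__ __forceinline__ float relu_via_mask(float x) {
  float mask = (x > 0.f) ? 1.f : 0.f;  // __hgt2 produces this 0/1 mask per fp16 lane
  return mask * x;                     // __hmul2 applies the mask in packed form
}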

template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
const half* bias,
half* data) {
int offset = blockIdx.x * N;
for (int i = threadIdx.x; i < N; i += BlockDim) {
half temp;
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
temp = __ldg(data + offset + i) + __ldg(bias + i);
#else
temp = data[offset + i] + bias[i];
#endif
if (DoRelu) {
#if __CUDA_ARCH__ >= 800
data[offset + i] = __hmax(0, temp);
#else
data[offset + i] = __hmul(__hgt(temp, 0), temp);
#endif
} else {
data[offset + i] = temp;
}
}
}

template <>
void AddReluKernel(cudaStream_t stream,
const int M,
const int N,
float16* Y,
const float16* B,
bool relu) {
if (N % 2 == 0) {
const int threads = 256;
const int num = M * N / 2;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<float16>::Type trans_type;
auto* bias_ptr_v2 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v2 = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v2<true><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v2, data_ptr_v2, N / 2);
} else {
bias_relu_v2<false><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v2, data_ptr_v2, N / 2);
}
} else {
const int threads = 256;
const int blocks = M;
auto* halfB = reinterpret_cast<const half*>(B);
auto* halfY = reinterpret_cast<half*>(Y);
if (relu) {
InplaceAddReluKernel<true, threads>
<<<blocks, threads, 0, stream>>>(N, halfB, halfY);
} else {
InplaceAddReluKernel<false, threads>
<<<blocks, threads, 0, stream>>>(N, halfB, halfY);
}
}
}
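The float16 specialization vectorizes by two (half2 is 4 bytes) and, for odd N, falls back to the scalar half kernel above by reinterpreting the float16 pointers as native half. That cast leans on phi::dtype::float16 and CUDA __half sharing the same 16-bit layout; a compile-time sketch of that assumption (header path assumed, not part of the diff):

#include <cuda_fp16.h>
#include "paddle/phi/common/float16.h"

// Layout assumptions behind the reinterpret_casts in the fp16 path above.
static_assert(sizeof(phi::dtype::float16) == sizeof(__half),
              "float16 must be layout-compatible with __half");
static_assert(sizeof(half2) == 2 * sizeof(__half),
              "half2 packs two __half values, hence the N % 2 == 0 check");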

#else

struct float16_4 {
float16 x, y, z, w;
};
template <>
struct FcTypeTraits<float16> {
typedef float16_4 Type;
};

template <bool DoRelu>
__global__ void bias_relu_v4(const int num,
const float16_4* bias,
float16_4* data,
int K) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const float16_4 bias_ptr = bias[bias_idx];
const float16_4 in_ptr = data[tid];
float16_4 packed_val;
packed_val.x = in_ptr.x + bias_ptr.x;
packed_val.y = in_ptr.y + bias_ptr.y;
packed_val.z = in_ptr.z + bias_ptr.z;
packed_val.w = in_ptr.w + bias_ptr.w;
if (DoRelu) {
packed_val.x = fmaxf(0.f, packed_val.x);
packed_val.y = fmaxf(0.f, packed_val.y);
packed_val.z = fmaxf(0.f, packed_val.z);
packed_val.w = fmaxf(0.f, packed_val.w);
}
data[tid] = packed_val;
}
}

template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
const float16* bias,
float16* data) {
int offset = blockIdx.x * N;

for (int i = threadIdx.x; i < N; i += BlockDim) {
float16 temp;
temp = data[offset + i] + bias[i];
if (DoRelu) {
data[offset + i] = fmaxf(0.f, temp);
} else {
data[offset + i] = temp;
}
}
}

template <>
void AddReluKernel(gpuStream_t stream,
const int M,
const int N,
float16* Y,
const float16* B,
bool relu) {
if (N % 4 == 0) {
const int threads = 256;
const int num = M * N / 4;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<float16>::Type trans_type;
auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v4<trans_type, true><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
} else {
bias_relu_v4<trans_type, false><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
}
} else {
const int threads = 256;
const int blocks = M;

if (relu) {
InplaceAddReluKernel<true, threads>
<<<blocks, threads, 0, stream>>>(N, B, Y);
} else {
InplaceAddReluKernel<false, threads>
<<<blocks, threads, 0, stream>>>(N, B, Y);
}
}
}
#endif
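In the non-CUDA branch above (the #else of PADDLE_WITH_CUDA, e.g. a ROCm build) the half2 intrinsics are not used; a plain float16_4 struct packs four scalars so that one load or store moves 8 bytes, which is why that path requires N % 4 == 0. A minimal host-side sketch of the packing assumption (stub types, not Paddle's):

#include <cstdint>
#include <cstdio>

struct fp16_stub { std::uint16_t bits; };      // stand-in for phi::dtype::float16
struct fp16_4_stub { fp16_stub x, y, z, w; };  // stand-in for float16_4

int main() {
  static_assert(sizeof(fp16_4_stub) == 8, "four fp16 values pack into 8 bytes");
  std::printf("sizeof=%zu alignof=%zu\n", sizeof(fp16_4_stub), alignof(fp16_4_stub));
  return 0;
}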

template <typename DeviceContext, typename T>
void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
const int M,
@@ -109,36 +321,14 @@ void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
}

// M * N
if (N % 4 == 0) {
const int threads = 256;
const int num = M * N / 4;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<T>::Type trans_type;
auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v4<trans_type, true><<<blocks, threads, 0, context.stream()>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
} else {
bias_relu_v4<trans_type, false><<<blocks, threads, 0, context.stream()>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
}
} else {
const int threads = 256;
const int blocks = M;
if (relu) {
InplaceAddReluKernel<T, true, threads>
<<<blocks, threads, 0, context.stream()>>>(N, B, Y);
} else {
InplaceAddReluKernel<T, false, threads>
<<<blocks, threads, 0, context.stream()>>>(N, B, Y);
}
}
AddReluKernel(context.stream(), M, N, Y, B, relu);
}

template class FCFunctor<paddle::platform::CUDADeviceContext, float16>;
template class FCFunctor<paddle::platform::CUDADeviceContext, float>;
template class FCFunctor<paddle::platform::CUDADeviceContext, double>;

template class FCFunctor<GPUContext, float16>;
template class FCFunctor<GPUContext, float>;
template class FCFunctor<GPUContext, double>;
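With this change FCFunctor::operator() keeps the GEMM that produces Y = X * W (in the lines not shown in this hunk) and hands the bias add plus optional ReLU to AddReluKernel, so float, double, and float16 share one code path, and the new instantiations export the float16 variant. A hedged usage sketch; the operator() signature is inferred from the call sites and fc_functor.h, so treat the parameter order and defaults as assumptions:

#include "paddle/phi/kernels/funcs/fc_functor.h"

// Illustration only: invoking the fused FC functor for fp16 data.
void RunFcFp16(const phi::GPUContext& ctx,
               int M, int N, int K,
               const phi::dtype::float16* X,    // input,  M x K
               const phi::dtype::float16* W,    // weight, K x N
               phi::dtype::float16* Y,          // output, M x N
               const phi::dtype::float16* B) {  // bias,   1 x N (may be nullptr)
  phi::funcs::FCFunctor<phi::GPUContext, phi::dtype::float16> fc;
  fc(ctx, M, N, K, X, W, Y, B, /*relu=*/true);
}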

Expand Down