From bf59751b7dbbf9cc21a87576894c3e670fb4e5ef Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Mon, 8 Jul 2024 19:03:06 +0800
Subject: [PATCH] layernorm support fp8 (#65791)

---
 paddle/phi/infermeta/multiary.cc              |  10 +-
 .../fusion/gpu/fused_layernorm_kernel.cu      | 169 ++++++++++++++----
 2 files changed, 144 insertions(+), 35 deletions(-)

diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 1b68617882c5d1..b095b662bcaac9 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -2425,10 +2425,14 @@ void FusedLayerNormInferMeta(const MetaTensor& x,
   if (residual_out && !norm_weight && !norm_bias) {
     out->set_dtype(x.dtype());
   } else {
-    if (quant_scale <= 0.0f) {
-      out->set_dtype(x.dtype());
+    if (quant_scale > 0) {
+      if (fabs(quant_max_bound - 127.0f) < 0.000001) {
+        out->set_dtype(phi::DataType::INT8);
+      } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
+        out->set_dtype(phi::DataType::FLOAT8_E4M3FN);
+      }
     } else {
-      out->set_dtype(phi::DataType::INT8);
+      out->set_dtype(x.dtype());
     }
   }
   out->set_layout(x.layout());
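Before the kernel changes below, it is worth restating the dtype rule that the infermeta hunk above introduces: quantization is active only when quant_scale > 0, and quant_max_bound then selects the storage type, 127 for INT8 and 448 for FLOAT8_E4M3FN (448 is the largest finite value of the e4m3fn format). The standalone sketch below restates that rule outside Paddle; SelectOutputDType and OutDType are illustrative names, not Paddle API. Note that the hunk above leaves the dtype unset when quant_scale > 0 but the bound matches neither 127 nor 448, whereas this sketch falls back to the input dtype.

#include <cmath>
#include <cstdio>

// Illustrative stand-in for phi::DataType with only the cases used here.
enum class OutDType { kSameAsInput, kInt8, kFloat8E4M3FN };

// Restates the selection rule from FusedLayerNormInferMeta:
// quant_scale > 0 enables quantization, and quant_max_bound picks the
// storage type (127 -> int8, 448 -> float8 e4m3fn).
OutDType SelectOutputDType(float quant_scale, float quant_max_bound) {
  if (quant_scale > 0.0f) {
    if (std::fabs(quant_max_bound - 127.0f) < 1e-6f) {
      return OutDType::kInt8;
    }
    if (std::fabs(quant_max_bound - 448.0f) < 1e-6f) {
      return OutDType::kFloat8E4M3FN;
    }
  }
  return OutDType::kSameAsInput;
}

int main() {
  // quant_scale <= 0 keeps the input dtype; a 448 bound selects the FP8 path.
  std::printf("%d\n", static_cast<int>(SelectOutputDType(0.0f, 448.0f)));  // 0
  std::printf("%d\n", static_cast<int>(SelectOutputDType(1.0f, 448.0f)));  // 2
  return 0;
}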
diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
index 7d0e69a38fbb13..ccbf68d6e5112d 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
@@ -862,6 +862,18 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input,
       ClipFunc<float>(quant_value, min_bound, max_bound));
 }
 
+template <typename InType, typename OutType>
+__forceinline__ __device__ OutType FP8QuantHelperFunc(const InType input,
+                                                      const float scale,
+                                                      const int round_type,
+                                                      const float max_bound,
+                                                      const float min_bound) {
+  // Scale into the quantized range and clip; rounding is left to the cast.
+  float quant_value = max_bound * scale * input;
+  return static_cast<OutType>(
+      ClipFunc<float>(quant_value, min_bound, max_bound));
+}
+
 template <
@@ ... @@
       float normalized_i = static_cast<float>(src[i]);
       float normalized_val =
           normalized_i * gamma_pack.elem[i] + beta_pack.elem[i];
-      y_pack.elem[i] = QuantHelperFunc<float, OutType>(normalized_val,
-                                                       quant_out_scale,
-                                                       quant_round_type,
-                                                       quant_max_bound,
-                                                       quant_min_bound);
+      if constexpr (std::is_same_v<OutType, phi::dtype::float8_e4m3fn>) {
+        y_pack.elem[i] = FP8QuantHelperFunc<float, OutType>(normalized_val,
+                                                            quant_out_scale,
+                                                            quant_round_type,
+                                                            quant_max_bound,
+                                                            quant_min_bound);
+      } else {
+        y_pack.elem[i] = QuantHelperFunc<float, OutType>(normalized_val,
+                                                         quant_out_scale,
+                                                         quant_round_type,
+                                                         quant_max_bound,
+                                                         quant_min_bound);
+      }
     }
     *(reinterpret_cast*>(y) + offset) = y_pack;
   }
@@ -990,6 +1010,39 @@ void FusedLayerNormKernel(const Context& dev_ctx,
                           DenseTensor* residual_out,
                           DenseTensor* mean,
                           DenseTensor* variance) {
+  if (out->dtype() == phi::DataType::INT8 ||
+      out->dtype() == phi::DataType::FLOAT8_E4M3FN) {
+    PADDLE_ENFORCE_EQ(
+        quant_scale != 0.0f,
+        true,
+        phi::errors::InvalidArgument(
+            "A quantized fused_bias_residual_layernorm output requires a "
+            "nonzero quant_scale, but received "
+            "quant_scale = %f.",
+            quant_scale));
+    PADDLE_ENFORCE_EQ(quant_round_type == 0 || quant_round_type == 1,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "A quantized fused_bias_residual_layernorm output "
+                          "requires quant_round_type to be 0 or 1, but "
+                          "received quant_round_type = %d.",
+                          quant_round_type));
+    PADDLE_ENFORCE_EQ(quant_max_bound != 0.0f,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "A quantized fused_bias_residual_layernorm output "
+                          "requires a nonzero quant_max_bound, but received "
+                          "quant_max_bound = %f.",
+                          quant_max_bound));
+    PADDLE_ENFORCE_EQ(quant_min_bound != 0.0f,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "A quantized fused_bias_residual_layernorm output "
+                          "requires a nonzero quant_min_bound, but received "
+                          "quant_min_bound = %f.",
+                          quant_min_bound));
+  }
+
   using U = phi::funcs::LayerNormParamType<T>;
   const T* x_data = x.data<T>();
   const U* norm_weight_data =
@@ -1034,21 +1087,7 @@ void FusedLayerNormKernel(const Context& dev_ctx,
     T* residual_out_data = dev_ctx.template Alloc<T>(residual_out);
     const T* residual_data = residual.get().data<T>();
     const T* bias_data = bias ? bias.get().data<T>() : nullptr;
-    if (quant_scale <= 0.0f) {
-      T* out_data = dev_ctx.template Alloc<T>(out);
-      residual_bias_add_layernorm_helper.LayernormResidualDropoutBias(
-          dev_ctx,
-          x_data,
-          residual_data,
-          bias_data,
-          norm_weight_data,
-          norm_bias_data,
-          residual_out_data,
-          nullptr,
-          out_data,
-          mean_data,
-          variance_data);
-    } else {
+    if (out->dtype() == phi::DataType::INT8) {
       // Quantize and output int8.
       int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
       SkipLoadAndStoreResidual load(x_data,
@@ -1074,17 +1113,52 @@ void FusedLayerNormKernel(const Context& dev_ctx,
           epsilon,
           mean_data /*ln_mean_data*/,
           variance_data /*ln_var_data*/);
+    } else if (out->dtype() == phi::DataType::FLOAT8_E4M3FN) {
+      // Quantize and output float8_e4m3fn.
+      phi::dtype::float8_e4m3fn* out_data =
+          dev_ctx.template Alloc<phi::dtype::float8_e4m3fn>(out);
+      SkipLoadAndStoreResidual load(x_data,
+                                    bias_data,
+                                    residual_data,
+                                    residual_out_data,
+                                    residual_alpha,
+                                    cols);
+      AffineQuantStore store(
+          out_data,
+          cols,
+          norm_weight_data,
+          norm_bias_data,
+          quant_scale,
+          quant_round_type,
+          quant_max_bound,
+          quant_min_bound);
+      DispatchLayerNorm(
+          dev_ctx.stream(),
+          load,
+          store,
+          rows,
+          cols,
+          epsilon,
+          mean_data /*ln_mean_data*/,
+          variance_data /*ln_var_data*/);
+    } else {
+      // No quantization.
+      T* out_data = dev_ctx.template Alloc<T>(out);
+      residual_bias_add_layernorm_helper.LayernormResidualDropoutBias(
+          dev_ctx,
+          x_data,
+          residual_data,
+          bias_data,
+          norm_weight_data,
+          norm_bias_data,
+          residual_out_data,
+          nullptr,
+          out_data,
+          mean_data,
+          variance_data);
     }
   } else {
-    if (quant_scale <= 0.0f) {
-      T* out_data = dev_ctx.template Alloc<T>(out);
-      layernorm_helper.ComputeForward(x_data,
-                                      norm_weight_data,
-                                      norm_bias_data,
-                                      out_data,
-                                      mean_data,
-                                      variance_data);
-    } else {
+    if (out->dtype() == phi::DataType::INT8) {
       // Quantize and output int8.
       int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
       DirectLoad load(x_data, cols);
@@ -1104,6 +1178,37 @@ void FusedLayerNormKernel(const Context& dev_ctx,
                         epsilon,
                         mean_data,
                         variance_data);
+    } else if (out->dtype() == phi::DataType::FLOAT8_E4M3FN) {
+      // Quantize and output float8_e4m3fn.
+      phi::dtype::float8_e4m3fn* out_data =
+          dev_ctx.template Alloc<phi::dtype::float8_e4m3fn>(out);
+      DirectLoad load(x_data, cols);
+      AffineQuantStore store(
+          out_data,
+          cols,
+          norm_weight_data,
+          norm_bias_data,
+          quant_scale,
+          quant_round_type,
+          quant_max_bound,
+          quant_min_bound);
+      DispatchLayerNorm(dev_ctx.stream(),
+                        load,
+                        store,
+                        rows,
+                        cols,
+                        epsilon,
+                        mean_data,
+                        variance_data);
+    } else {
+      // No quantization.
+      T* out_data = dev_ctx.template Alloc<T>(out);
+      layernorm_helper.ComputeForward(x_data,
+                                      norm_weight_data,
+                                      norm_bias_data,
+                                      out_data,
+                                      mean_data,
+                                      variance_data);
     }
   }
 }
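For the kernel half of the patch, the arithmetic of the new FP8 path is easy to state in isolation: FP8QuantHelperFunc multiplies the normalized value by quant_max_bound * quant_scale, clips the result to [quant_min_bound, quant_max_bound], and casts it to phi::dtype::float8_e4m3fn; unlike the int8 helper, it applies no explicit round_type step, so rounding is left to the float-to-FP8 conversion. The host-side sketch below restates that arithmetic under those assumptions; ClipToRange and Fp8Quantize are illustrative names rather than Paddle APIs, and a plain float stands in for the FP8 storage type.

#include <algorithm>
#include <cstdio>

// Clip a value into [lo, hi], as the device-side ClipFunc does.
float ClipToRange(float v, float lo, float hi) {
  return std::min(std::max(v, lo), hi);
}

// Mirrors FP8QuantHelperFunc: scale into the quantized range and clip.
// In the kernel the clipped value is then cast to float8_e4m3fn, whose
// largest finite value is 448, matching the quant_max_bound used for FP8.
float Fp8Quantize(float x, float scale, float max_bound, float min_bound) {
  float quant_value = max_bound * scale * x;
  return ClipToRange(quant_value, min_bound, max_bound);
}

int main() {
  // With quant_max_bound = 448 and a scale of 0.5, an input of 3.0 is
  // clipped to 448.0, while 0.01 maps to 2.24.
  std::printf("%f\n", Fp8Quantize(3.0f, 0.5f, 448.0f, -448.0f));
  std::printf("%f\n", Fp8Quantize(0.01f, 0.5f, 448.0f, -448.0f));
  return 0;
}

Apart from this helper, the FP8 branches reuse the same load/store plumbing as the existing int8 branches; only the allocated element type and the quantization helper differ.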