layernorm support fp8 (PaddlePaddle#65791)
Wangzheee authored Jul 8, 2024
1 parent 4b044d5 commit bf59751
Showing 2 changed files with 144 additions and 35 deletions.
10 changes: 7 additions & 3 deletions paddle/phi/infermeta/multiary.cc
@@ -2425,10 +2425,14 @@ void FusedLayerNormInferMeta(const MetaTensor& x,
if (residual_out && !norm_weight && !norm_bias) {
out->set_dtype(x.dtype());
} else {
-    if (quant_scale <= 0.0f) {
-      out->set_dtype(x.dtype());
+    if (quant_scale > 0) {
+      if (fabs(quant_max_bound - 127.0f) < 0.000001) {
+        out->set_dtype(phi::DataType::INT8);
+      } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
+        out->set_dtype(phi::DataType::FLOAT8_E4M3FN);
+      }
     } else {
-      out->set_dtype(phi::DataType::INT8);
+      out->set_dtype(x.dtype());
}
}
out->set_layout(x.layout());
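
Note: 448.0f is the largest finite value of the float8 E4M3 format, so the rule above reads as "quant_max_bound near 127 selects INT8, near 448 selects FLOAT8_E4M3FN, and quant_scale <= 0 keeps the input dtype". A minimal standalone sketch of that selection, with DType and SelectQuantOutDtype as illustrative stand-ins rather than names from this commit:

#include <cmath>

// Stand-in for phi::DataType in this sketch.
enum class DType { kSameAsInput, kInt8, kFloat8E4M3FN };

// Mirrors the output-dtype rule added to FusedLayerNormInferMeta above.
inline DType SelectQuantOutDtype(float quant_scale, float quant_max_bound) {
  if (quant_scale > 0) {
    if (std::fabs(quant_max_bound - 127.0f) < 1e-6f) return DType::kInt8;
    if (std::fabs(quant_max_bound - 448.0f) < 1e-6f) return DType::kFloat8E4M3FN;
  }
  return DType::kSameAsInput;  // quant_scale <= 0: no quantization requested
}
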
169 changes: 137 additions & 32 deletions paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
@@ -862,6 +862,18 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input,
ClipFunc<float>(quant_value, min_bound, max_bound));
}

template <typename InType, typename OutType>
__forceinline__ __device__ OutType FP8QuantHelperFunc(const InType input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * input;
// float quant_value = input;
return static_cast<OutType>(
ClipFunc<float>(quant_value, min_bound, max_bound));
}

template <typename OutType,
typename SRC,
typename DST,
@@ -873,9 +885,9 @@ struct AffineQuantStore {
const float* gamma,
const float* beta,
const float quant_out_scale,
-                   const int quant_round_type = 1,
-                   const float quant_max_bound = 127.0,
-                   const float quant_min_bound = -127.0)
+                   const int quant_round_type,
+                   const float quant_max_bound,
+                   const float quant_min_bound)
: y(y),
row_size(row_size),
gamma(gamma),
@@ -901,11 +913,19 @@ struct AffineQuantStore {
float normalized_i = static_cast<float>(src[i]);
float normalized_val =
normalized_i * gamma_pack.elem[i] + beta_pack.elem[i];
-      y_pack.elem[i] = QuantHelperFunc<float, OutType>(normalized_val,
-                                                       quant_out_scale,
-                                                       quant_round_type,
-                                                       quant_max_bound,
-                                                       quant_min_bound);
+      if constexpr (std::is_same_v<OutType, phi::dtype::float8_e4m3fn>) {
+        y_pack.elem[i] = FP8QuantHelperFunc<float, OutType>(normalized_val,
+                                                            quant_out_scale,
+                                                            quant_round_type,
+                                                            quant_max_bound,
+                                                            quant_min_bound);
+      } else {
+        y_pack.elem[i] = QuantHelperFunc<float, OutType>(normalized_val,
+                                                         quant_out_scale,
+                                                         quant_round_type,
+                                                         quant_max_bound,
+                                                         quant_min_bound);
+      }
}
*(reinterpret_cast<Pack<OutType, N>*>(y) + offset) = y_pack;
}
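
The FP8 store path above scales by max_bound * scale and clips to [min_bound, max_bound] before the static_cast to float8_e4m3fn performs the narrowing; the INT8 path still goes through QuantHelperFunc, which also receives quant_round_type (its body sits above the shown hunk). A host-side sketch of the FP8 math, with Fp8QuantizeSketch as an illustrative name and plain float standing in for phi::dtype::float8_e4m3fn:

#include <algorithm>

// Same arithmetic as FP8QuantHelperFunc: scale into the e4m3 range and clip.
inline float Fp8QuantizeSketch(float input,
                               float scale,
                               float max_bound = 448.0f,
                               float min_bound = -448.0f) {
  float quant_value = max_bound * scale * input;
  return std::min(std::max(quant_value, min_bound), max_bound);  // ClipFunc equivalent
}
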
@@ -990,6 +1010,39 @@ void FusedLayerNormKernel(const Context& dev_ctx,
DenseTensor* residual_out,
DenseTensor* mean,
DenseTensor* variance) {
if (out->dtype() == phi::DataType::INT8 ||
out->dtype() == phi::DataType::FLOAT8_E4M3FN) {
    PADDLE_ENFORCE_EQ(
        quant_scale != 0.0f,
        true,
        phi::errors::InvalidArgument(
            "A quantized fused_bias_residual_layernorm output requires a "
            "non-zero quant_scale, but quant_scale = %f.",
            quant_scale));
    PADDLE_ENFORCE_EQ(
        quant_round_type == 0 || quant_round_type == 1,
        true,
        phi::errors::InvalidArgument(
            "A quantized fused_bias_residual_layernorm output requires "
            "quant_round_type to be 0 or 1, but quant_round_type = %d.",
            quant_round_type));
    PADDLE_ENFORCE_EQ(
        quant_max_bound != 0.0f,
        true,
        phi::errors::InvalidArgument(
            "A quantized fused_bias_residual_layernorm output requires a "
            "non-zero quant_max_bound, but quant_max_bound = %f.",
            quant_max_bound));
    PADDLE_ENFORCE_EQ(
        quant_min_bound != 0.0f,
        true,
        phi::errors::InvalidArgument(
            "A quantized fused_bias_residual_layernorm output requires a "
            "non-zero quant_min_bound, but quant_min_bound = %f.",
            quant_min_bound));
}

using U = phi::funcs::LayerNormParamType<T>;
const T* x_data = x.data<T>();
const U* norm_weight_data =
@@ -1034,21 +1087,7 @@ void FusedLayerNormKernel(const Context& dev_ctx,
T* residual_out_data = dev_ctx.template Alloc<T>(residual_out);
const T* residual_data = residual.get().data<T>();
const T* bias_data = bias ? bias.get().data<T>() : nullptr;
-    if (quant_scale <= 0.0f) {
-      T* out_data = dev_ctx.template Alloc<T>(out);
-      residual_bias_add_layernorm_helper.LayernormResidualDropoutBias(
-          dev_ctx,
-          x_data,
-          residual_data,
-          bias_data,
-          norm_weight_data,
-          norm_bias_data,
-          residual_out_data,
-          nullptr,
-          out_data,
-          mean_data,
-          variance_data);
-    } else {
+    if (out->dtype() == phi::DataType::INT8) {
// Quantize and output int8.
int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
SkipLoadAndStoreResidual<T> load(x_data,
@@ -1074,17 +1113,52 @@
epsilon,
mean_data /*ln_mean_data*/,
variance_data /*ln_var_data*/);
} else if (out->dtype() == phi::DataType::FLOAT8_E4M3FN) {
// Quantize and output float8_e4m3fn.
phi::dtype::float8_e4m3fn* out_data =
dev_ctx.template Alloc<phi::dtype::float8_e4m3fn>(out);
SkipLoadAndStoreResidual<T> load(x_data,
bias_data,
residual_data,
residual_out_data,
residual_alpha,
cols);
AffineQuantStore<phi::dtype::float8_e4m3fn, U, T, true, true> store(
out_data,
cols,
norm_weight_data,
norm_bias_data,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
DispatchLayerNorm<decltype(load), decltype(store), U>(
dev_ctx.stream(),
load,
store,
rows,
cols,
epsilon,
mean_data /*ln_mean_data*/,
variance_data /*ln_var_data*/);
} else {
// No Quantize.
T* out_data = dev_ctx.template Alloc<T>(out);
residual_bias_add_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
x_data,
residual_data,
bias_data,
norm_weight_data,
norm_bias_data,
residual_out_data,
nullptr,
out_data,
mean_data,
variance_data);
}
} else {
-    if (quant_scale <= 0.0f) {
-      T* out_data = dev_ctx.template Alloc<T>(out);
-      layernorm_helper.ComputeForward(x_data,
-                                      norm_weight_data,
-                                      norm_bias_data,
-                                      out_data,
-                                      mean_data,
-                                      variance_data);
-    } else {
+    if (out->dtype() == phi::DataType::INT8) {
// Quantize and output int8.
int8_t* out_data = dev_ctx.template Alloc<int8_t>(out);
DirectLoad<T, U> load(x_data, cols);
@@ -1104,6 +1178,37 @@
epsilon,
mean_data,
variance_data);
} else if (out->dtype() == phi::DataType::FLOAT8_E4M3FN) {
// Quantize and output float8_e4m3fn.
phi::dtype::float8_e4m3fn* out_data =
dev_ctx.template Alloc<phi::dtype::float8_e4m3fn>(out);
DirectLoad<T, U> load(x_data, cols);
AffineQuantStore<phi::dtype::float8_e4m3fn, U, T, true, true> store(
out_data,
cols,
norm_weight_data,
norm_bias_data,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
DispatchLayerNorm<decltype(load), decltype(store), U>(dev_ctx.stream(),
load,
store,
rows,
cols,
epsilon,
mean_data,
variance_data);
} else {
// No Quantize.
T* out_data = dev_ctx.template Alloc<T>(out);
layernorm_helper.ComputeForward(x_data,
norm_weight_data,
norm_bias_data,
out_data,
mean_data,
variance_data);
}
}
}
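
Taken together, the kernel now picks one of six output paths from whether a residual input is provided and from the output dtype chosen by FusedLayerNormInferMeta. A comment-only summary of the branches above:

// residual present, out INT8          -> SkipLoadAndStoreResidual + AffineQuantStore<int8_t, ...>
// residual present, out FLOAT8_E4M3FN -> SkipLoadAndStoreResidual + AffineQuantStore<float8_e4m3fn, ...>
// residual present, no quantization   -> residual_bias_add_layernorm_helper.LayernormResidualDropoutBias
// no residual,      out INT8          -> DirectLoad + AffineQuantStore<int8_t, ...>
// no residual,      out FLOAT8_E4M3FN -> DirectLoad + AffineQuantStore<float8_e4m3fn, ...>
// no residual,      no quantization   -> layernorm_helper.ComputeForward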