Optimize layer norm forward when cols is 1024. #39167

Merged
merged 11 commits on Jan 26, 2022
230 changes: 223 additions & 7 deletions paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
@@ -19,6 +19,8 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define LN_NUM_COLS 1024

template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
@@ -153,6 +155,191 @@ __global__ void FusedLayernormResidualDropoutBias(
invvar);
}

/*
* @brief layernorm(residual + dropout(x));
* Conditions:
* (1) the number of cols is 1024;
* (2) layer_norm scale and bias are not null;
* (3) linear bias is null;
* @param
* rows: batch_size * seq_len
* cols: 1024
* x_: [rows, cols], inputs
* residual_: [rows, cols]
* gamma_: [cols], layernorm scale, not null
* beta_: [cols], layernorm bias, not null
* mask_out_: [rows, cols], dropout result
* residual_out_: [rows, cols], residual + dropout(src)
* y_: [rows, cols], layernorm result
* mean_out_: [rows], layernorm means
* var_out_: [rows], layernorm vars
*/
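/*
* Reader's note on the launch geometry (inferred from the template
* parameters): each CTA runs WARPS_M (4) warps of 32 threads and WARPS_N = 1,
* so a single warp covers one full row of 1024 elements. Every thread issues
* LDGS = 1024 / (32 * VecSize) vectorized loads of VecSize elements, i.e.
* 32 elements per thread per row.
*/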
template <
typename T, typename U, typename ScaleT = U, typename MaskType = uint8_t,
int VecSize = 8, int WARPS_M = 4, int WARPS_N = 1, int BYTES_PER_LDG = 16,
int ELTS_PER_ROW = 1024, int THREADS_PER_WARP = 32,
int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP,
int THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
int rows, int cols, uint64_t seed, const float dropout_prob,
const bool is_upscale_in_train, const bool is_test,
const uint64_t increment, const float epsilon, const T *__restrict__ x_ptr,
const T *__restrict__ residual_ptr, const ScaleT *__restrict__ gamma_ptr,
const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr,
U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) {
using Vec = platform::AlignedVector<T, VecSize>;
using Vec_scale = platform::AlignedVector<ScaleT, VecSize>;
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;

const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int lane = tidx % THREADS_PER_WARP; // 0, 1, ..., 31
const int warp = tidx / THREADS_PER_WARP; // 0, 1, 2, 3
const int warp_n = warp % WARPS_N; // 0
const int warp_m = warp / WARPS_N; // 0, 1, 2, 3

const int c = warp_n * THREADS_PER_WARP + lane; // lane
const int r = bidx * ROWS_PER_CTA + warp_m; // row id

int idx = r * LN_NUM_COLS + c;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);

T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);

Vec_scale gamma[LDGS];
Vec_scale beta[LDGS];
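// Preload gamma and beta into registers once; they are reused for every row
// handled by this thread.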
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
platform::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
col += THREADS_PER_ROW;
}

constexpr U rn = 1.f / U(LN_NUM_COLS);
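// Grid-stride loop over rows: each CTA processes ROWS_PER_CTA rows per
// iteration, then advances by gridDim.x * ROWS_PER_CTA.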
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
Vec residual[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize,
&x[it]);
platform::Load<T, VecSize>(
residual_ptr + row * LN_NUM_COLS + col * VecSize, &residual[it]);
col += THREADS_PER_ROW;
}

MaskStoreT mask_vec[LDGS];
if (!is_test) {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
float rand[VecSize];
RandVec<VecSize>(&state, rand);
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
mask_vec[it][jt] = static_cast<MaskType>(rand[jt] >= dropout_prob);
}
}
} else {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
mask_vec[it][jt] = static_cast<MaskType>(1);
}
}
}

// LDGS * VecSize (e.g. 4 * 8) values per thread, kept in the compute type U
// for the mean/var computation below.
U xf[LDGS * VecSize];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
xf[it * VecSize + jt] = U(x[it][jt]);
}
}

// store dropout_residual_out and mask_out
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Store<T, VecSize>(
x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize);
platform::Store<MaskType, VecSize>(
mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize);
col += THREADS_PER_ROW;
}

U mu_local = 0.f;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
mu_local += xf[it * VecSize + jt];
}
}

#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
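// After the XOR butterfly reduction above every lane holds the same row sum;
// scale by 1/1024 and let lane 0 write out the mean.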
mu_local *= rn;
if (lane == 0) {
mean_out_ptr[row] = mu_local;
}
U var_local = 0.f;

#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
U diff = xf[it * VecSize + jt] - mu_local;
var_local += diff * diff;
}
}

#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
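// Every lane now holds the summed squared deviations for the row;
// rsigma = 1 / sqrt(var + epsilon) with var = sum / 1024.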
U rsigma = rsqrtf(var_local * rn + epsilon);
if (lane == 0) {
// Note: Paddle's layer_norm stores the variance here, while Apex's fast ln
// stores rsigma instead.
// var_out_ptr[row] = rsigma;
var_out_ptr[row] = var_local * rn;
}

#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// use fp16 to compute
// ScaleT tmp = static_cast<ScaleT>(rsigma * (xf[it * VecSize + jt] -
// mu_local));
// x[it][jt] = gamma[it][jt] * tmp + beta[it][jt];
// cast to fp32 to compute
U tmp = rsigma * (static_cast<U>(xf[it * VecSize + jt]) - mu_local);
x[it][jt] = static_cast<T>(static_cast<U>(gamma[it][jt]) * tmp +
static_cast<U>(beta[it][jt]));
}
}

#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
platform::Store<T, VecSize>(x[it],
y_ptr + row * LN_NUM_COLS + col * VecSize);
col += THREADS_PER_ROW;
}
}
}

/**
* @brief layernorm(residual + dropout(src + bias));
* @param
Expand Down Expand Up @@ -205,6 +392,13 @@ void LaunchLayernormResidualDropoutBias(
return;
}

bool can_call_1024_kernel = false;
if (cols == 1024 && scale != nullptr && layernorm_bias != nullptr &&
bias == nullptr) {
can_call_1024_kernel = true;
}
VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;

const int VecSize = MAX_CACHE_BYTES / sizeof(T);
if (cols % VecSize != 0) {
int blockDim = GetDesiredBlockDim(cols);
@@ -215,13 +409,35 @@
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var);
} else {
int blockDim = GetDesiredBlockDim(cols / VecSize);
FusedLayernormResidualDropoutBias<
T, uint8_t, VecSize, U,
ScaleBiasWithSameTypeX><<<rows, blockDim, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var);
if (can_call_1024_kernel) {
const int WARPS_M = 4;
const int WARPS_N = 1;
const int THREADS_PER_WARP = 32;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);
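// With 16-byte loads, VecSize is 8 for fp16 and 4 for fp32; the kernel
// derives LDGS = 1024 / (32 * VecSize) accordingly.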

const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
const int ROWS_PER_CTA = WARPS_M;
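// 4 warps * 32 threads = 128 threads per CTA; each CTA covers ROWS_PER_CTA
// (4) rows per iteration.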

// Note: the grid size must not exceed the GPU's maximum grid dimension.
const int grid =
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA)));
fused_ln_fwd_1024_kernel<
T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, uint8_t,
VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
increment, epsilon, src, residual, scale, layernorm_bias, mask_data,
mean, var, dst, layernorm_dst);
} else {
int blockDim = GetDesiredBlockDim(cols / VecSize);
FusedLayernormResidualDropoutBias<
T, uint8_t, VecSize, U,
ScaleBiasWithSameTypeX><<<rows, blockDim, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
increment, epsilon, src, residual, bias, scale, layernorm_bias,
mask_data, dst, layernorm_dst, mean, var);
}
}
}

@@ -66,20 +66,18 @@ struct TestFusedLayernormResidualDropoutBias {
ctx = reinterpret_cast<platform::CUDADeviceContext *>(devicectx);
}

- TestFusedLayernormResidualDropoutBias(int _rows, int _cols,
- uint64_t _seed = 0,
- float _dropout_prob = 0.0,
- float _epsilon = 0.00001f,
- bool _is_upscale_in_train = false,
- bool _is_test = false) {
+ TestFusedLayernormResidualDropoutBias(
+ int _rows, int _cols, uint64_t _seed = 0, float _dropout_prob = 0.0,
+ float _epsilon = 0.00001f, bool _is_upscale_in_train = false,
+ bool _is_test = false, bool _has_bias = true) {
rows = _rows;
cols = _cols;
seed = _seed;
dropout_prob = _dropout_prob;
epsilon = _epsilon;
is_upscale_in_train = _is_upscale_in_train;
is_test = _is_test;
- has_bias = true;
+ has_bias = _has_bias;
has_scale = true;
has_layernorm_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
@@ -283,7 +281,6 @@ static void BaseTest(const bool is_fp16 = false) {
}
}
}

TEST(FusedDropout, GPUFusedLayernormResidualDropoutBias) { BaseTest<float>(); }

TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasDouble) {
@@ -330,3 +327,12 @@ TEST(FusedDropout, GPUFusedLayernormResidualDropoutLargeShape) {
test.Run();
test.CheckOut(static_cast<float>(1e-4));
}

TEST(FusedDropout, GPUFusedLayernormResidualDropoutFp16MLperf) {
const int rows = 512;
const int cols = 1024;
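// Constructor args: rows, cols, seed = 0, dropout_prob = 0, epsilon = 1e-5,
// is_upscale_in_train = false, is_test = false, has_bias = false, so the new
// 1024-column, no-linear-bias fast path is exercised.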
TestFusedLayernormResidualDropoutBias<platform::float16> test(
rows, cols, 0, 0, 0.00001f, false, false, false);
test.Run();
test.CheckOut(static_cast<platform::float16>(1e-2));
}