
Extend forward fast layer_norm kernel to support more dimensions. #43118

Merged: 14 commits, Jun 2, 2022
@@ -478,11 +478,15 @@ void LaunchLayernormResidualDropoutBias(
#define LAUNCH_FUSED_FAST_LN_KERNEL \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1280); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1536); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1792); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(2048); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)

bool can_call_fast_ln_kernel = false;
if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr &&
layernorm_bias != nullptr) {
if (((cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096) &&
scale != nullptr && layernorm_bias != nullptr) {
can_call_fast_ln_kernel = true;
}
VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;
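Note: the condition above is the core of the change on the fused path. The fast layer_norm kernel now covers every multiple of 256 between 768 and 2048, plus 4096, instead of only 768, 1024, and 4096. A minimal standalone sketch of the new dispatch check (the helper name is illustrative, not part of the PR):

#include <cstdint>

// Mirrors the new can_call_fast_ln_kernel condition: the fast kernel is used
// only when scale and layernorm_bias are present and the row width is one of
// 768, 1024, 1280, 1536, 1792, 2048, or 4096.
inline bool CanUseFastLayerNormFwd(int64_t cols, bool has_scale, bool has_bias) {
  const bool supported_cols =
      (cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096;
  return supported_cols && has_scale && has_bias;
}
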
89 changes: 62 additions & 27 deletions paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -36,8 +36,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

#define LN_NUM_COLS 1024

inline static int GetDesiredBlockDim(int64_t block_dim) {
#ifdef __HIPCC__
const int kMaxBlockDim = 256;
@@ -183,11 +181,12 @@ template <typename T, typename U, typename ScaleT = U, int VecSize = 8,
int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
__global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
int rows, int cols, const float epsilon, const T *__restrict__ x_ptr,
const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
T *__restrict__ y_ptr) {
__shared__ U smem[WARPS_M * WARPS_N];
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;

@@ -210,12 +209,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
col += THREADS_PER_ROW;
}

constexpr U rn = 1.f / U(LN_NUM_COLS);
constexpr U rn = 1.f / U(ELTS_PER_ROW);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
col += THREADS_PER_ROW;
}
U xf[LDGS * VecSize];
@@ -235,6 +234,23 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
if (WARPS_N > 1) {
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = mu_local;
}
__syncthreads();
if (tidx == 0) {
mu_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
mu_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = mu_local;
}
__syncthreads();
mu_local = smem[warp_m];
}

mu_local *= rn;
if (lane == 0) {
mean_out_ptr[row] = mu_local;
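The WARPS_N > 1 block added above is the main change inside the kernel itself: for widths that need more than one warp per row (2048 and 4096 under the launch rules below), the per-warp __shfl_xor_sync result is combined across warps through shared memory before the mean is formed; the hunk that follows repeats the same pattern for the variance. A self-contained sketch of this two-level reduction, with one row per block for simplicity rather than the PR's ROWS_PER_CTA tiling:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// One block reduces one row of cols floats to its mean using the same
// two-level scheme as fast_ln_fwd_kernel: an intra-warp butterfly reduction
// with __shfl_xor_sync, then a shared-memory reduction across the warps.
template <int WARPS_N>
__global__ void RowMeanKernel(const float *x, float *mean, int cols) {
  __shared__ float smem[WARPS_N];
  const int tidx = threadIdx.x;
  const int lane = tidx % 32;
  const int warp = tidx / 32;
  const int row = blockIdx.x;

  // Strided per-thread partial sum over the row.
  float sum = 0.f;
  for (int col = tidx; col < cols; col += blockDim.x) {
    sum += x[row * cols + col];
  }

  // Level 1: butterfly reduction inside each warp.
  for (int it = 1; it < 32; it *= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, it);
  }

  // Level 2: fold the per-warp partials through shared memory.
  if (WARPS_N > 1) {
    if (lane == 0) smem[warp] = sum;
    __syncthreads();
    if (tidx == 0) {
      sum = 0.f;
      for (int it = 0; it < WARPS_N; ++it) sum += smem[it];
      smem[0] = sum;
    }
    __syncthreads();
    sum = smem[0];  // broadcast the row total to every thread
  }

  if (tidx == 0) mean[row] = sum / cols;
}

int main() {
  constexpr int kRows = 4, kCols = 2048, kWarpsN = 2;
  float *x = nullptr, *mean = nullptr;
  cudaMallocManaged(&x, sizeof(float) * kRows * kCols);
  cudaMallocManaged(&mean, sizeof(float) * kRows);
  for (int i = 0; i < kRows * kCols; ++i) x[i] = 0.5f;  // every row mean is 0.5
  RowMeanKernel<kWarpsN><<<kRows, kWarpsN * 32>>>(x, mean, kCols);
  cudaDeviceSynchronize();
  printf("mean[0] = %f\n", mean[0]);  // expect 0.500000
  cudaFree(x);
  cudaFree(mean);
  return 0;
}

For 768 through 1792 columns WARPS_N stays 1, so the shared-memory step is skipped and those widths behave like the original 1024-only kernel.
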
@@ -254,6 +270,24 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}

if (WARPS_N > 1) {
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = var_local;
}
__syncthreads();
if (tidx == 0) {
var_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
var_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = var_local;
}
__syncthreads();
var_local = smem[warp_m];
}

// Note: need to verify whether this is correct for double.
U rsigma = rsqrtf(var_local * rn + epsilon);
if (lane == 0) {
@@ -277,7 +311,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(

#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
col += THREADS_PER_ROW;
}
}
@@ -416,10 +450,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
const int r = bidx * ROWS_PER_CTA + warp_m;
const int c = warp_n * THREADS_PER_WARP + lane;

static_assert(LN_NUM_COLS == THREADS_PER_ROW * LDGS * VecSize, "");
static_assert(ELTS_PER_ROW == THREADS_PER_ROW * LDGS * VecSize, "");

// smem for column reduction
__shared__ U smem_[ROWS_PER_CTA * LN_NUM_COLS];
__shared__ U smem_[ROWS_PER_CTA * ELTS_PER_ROW];

U dgamma_sum[LDGS * VecSize];
U dbeta_sum[LDGS * VecSize];
@@ -434,7 +468,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
U *sum_loss2_shared = &smem_sum_loss2[warp_m * WARPS_N];

// step-1: compute dx and local results of dscale and dbias
constexpr float rn = 1.f / static_cast<float>(LN_NUM_COLS);
constexpr float rn = 1.f / static_cast<float>(ELTS_PER_ROW);
Vec_scale gamma[LDGS];
int col = c;
#pragma unroll
@@ -452,12 +486,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
phi::Load<T, VecSize>(dout_ptr + row * LN_NUM_COLS + col * VecSize,
phi::Load<T, VecSize>(dout_ptr + row * ELTS_PER_ROW + col * VecSize,
&dout[it]);
phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
if (isFusedDropoutResidualLn) {
phi::Load<MaskType, VecSize>(
mask_ptr + row * LN_NUM_COLS + col * VecSize, &mask_vec[it]);
mask_ptr + row * ELTS_PER_ROW + col * VecSize, &mask_vec[it]);
}

col += THREADS_PER_ROW;
@@ -551,23 +585,24 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
phi::Store<T, VecSize>(x[it], dx_ptr + row * LN_NUM_COLS + col * VecSize);
phi::Store<T, VecSize>(x[it],
dx_ptr + row * ELTS_PER_ROW + col * VecSize);
if (isFusedDropoutResidualLn) {
phi::Store<T, VecSize>(
dout[it], d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize);
dout[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
}
col += THREADS_PER_ROW;
}
}

// step-2: column reduction of dscale and dbias for each thread block.
// each block's sum: [4 * 1024] -> [1 * 1024]
enum { NUM_RES = LN_NUM_COLS / THREADS_PER_CTA }; // 1024/128 = 8
static_assert(NUM_RES * THREADS_PER_CTA == LN_NUM_COLS, "");
enum { NUM_RES = ELTS_PER_ROW / THREADS_PER_CTA }; // 1024/128 = 8
static_assert(NUM_RES * THREADS_PER_CTA == ELTS_PER_ROW, "");

U *smem_write;

smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize]; // [4 * 1024]
smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize]; // [4 * 1024]
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
@@ -583,12 +618,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dbeta_sum[jt] +=
smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
}
}
__syncthreads();

smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];
smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
@@ -603,19 +638,19 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dgamma_sum[jt] +=
smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
}
}

// the shape of results:(#blocks, 1024)
U *dgamma_part =
static_cast<U *>(dgamma_temp_ptr) + bidx * LN_NUM_COLS + tidx;
static_cast<U *>(dgamma_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dgamma_sum[jt];
dgamma_part += THREADS_PER_CTA;
}

U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * LN_NUM_COLS + tidx;
U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dbeta_part = cta_dbeta_sum[jt];
dbeta_part += THREADS_PER_CTA;
@@ -640,7 +675,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
using Vec = phi::AlignedVector<U, VecSize>;
static_assert(VEC_COLS == LN_NUM_COLS / VecSize, "");
static_assert(VEC_COLS == ELTS_PER_ROW / VecSize, "");

const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
@@ -656,8 +691,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
__shared__ U smem_space[(WARPS_M - 1) * THREADS_PER_ROW * VecSize];

for (int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW) {
const U *dg_part_ptr = (dg_part_) + r * LN_NUM_COLS + col * VecSize;
const U *db_part_ptr = (db_part_) + r * LN_NUM_COLS + col * VecSize;
const U *dg_part_ptr = (dg_part_) + r * ELTS_PER_ROW + col * VecSize;
const U *db_part_ptr = (db_part_) + r * ELTS_PER_ROW + col * VecSize;

U dg_sum[VecSize];
U db_sum[VecSize];
@@ -669,8 +704,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
Vec db;
phi::Load<U, VecSize>(dg_part_ptr, &dg);
phi::Load<U, VecSize>(db_part_ptr, &db);
dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
dg_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
db_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;

#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
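For context on the backward kernels in this file, which this PR touches only through the LN_NUM_COLS to ELTS_PER_ROW rename: dscale/dbias require a reduction over rows, and fused_ln_bwd_1024_kernel does it in two stages, writing a per-block partial of shape (#blocks, ELTS_PER_ROW) that ln_bwd_1024_final_kernel then folds over the block dimension. A heavily simplified sketch of that two-stage column reduction (plain global-memory version, without the vectorized loads and shared-memory staging of the real kernels):

#include <cstdio>
#include <cuda_runtime.h>

// Stage 1: each block owns a strided slice of rows and emits one partial row
// of column sums at part[blockIdx.x * cols + col].
__global__ void ColSumPartial(const float *x, float *part, int rows, int cols) {
  for (int col = threadIdx.x; col < cols; col += blockDim.x) {
    float acc = 0.f;
    for (int row = blockIdx.x; row < rows; row += gridDim.x) {
      acc += x[row * cols + col];
    }
    part[blockIdx.x * cols + col] = acc;
  }
}

// Stage 2: fold the per-block partial rows into the final column sums.
__global__ void ColSumFinal(const float *part, float *out, int nblocks,
                            int cols) {
  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < cols;
       col += gridDim.x * blockDim.x) {
    float acc = 0.f;
    for (int b = 0; b < nblocks; ++b) acc += part[b * cols + col];
    out[col] = acc;
  }
}

int main() {
  constexpr int kRows = 4096, kCols = 1024, kBlocks = 32, kThreads = 128;
  float *x = nullptr, *part = nullptr, *out = nullptr;
  cudaMallocManaged(&x, sizeof(float) * kRows * kCols);
  cudaMallocManaged(&part, sizeof(float) * kBlocks * kCols);
  cudaMallocManaged(&out, sizeof(float) * kCols);
  for (int i = 0; i < kRows * kCols; ++i) x[i] = 1.f;  // every column sums to kRows
  ColSumPartial<<<kBlocks, kThreads>>>(x, part, kRows, kCols);
  ColSumFinal<<<kCols / kThreads, kThreads>>>(part, out, kBlocks, kCols);
  cudaDeviceSynchronize();
  printf("out[0] = %f (expect %d)\n", out[0], kRows);
  cudaFree(x);
  cudaFree(part);
  cudaFree(out);
  return 0;
}
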
117 changes: 67 additions & 50 deletions paddle/phi/kernels/gpu/layer_norm_kernel.cu
@@ -84,7 +84,7 @@ void LayerNormKernel(const Context &dev_ctx,
PADDLE_ENFORCE_EQ(
scale->dtype(),
bias->dtype(),
phi::errors::InvalidArgument("Thie Scale and Bias of layer_norm op "
phi::errors::InvalidArgument("This Scale and Bias of layer_norm op "
"should have the same data type."));
}
} else {
Expand Down Expand Up @@ -131,59 +131,75 @@ void LayerNormKernel(const Context &dev_ctx,
} \
} while (0)

#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, feature_size) \
case (feature_size): { \
constexpr int WARPS_N = feature_size < 1024 ? 1 : (feature_size / 1024); \
constexpr int WARPS_M = 4 / WARPS_N; \
const int THREADS_PER_WARP = 32; \
const int BYTES_PER_LDG = 16; \
const int VecSize = BYTES_PER_LDG / sizeof(T); \
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \
const int ROWS_PER_CTA = WARPS_M; \
const int grid = static_cast<int>( \
std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA))); \
paddle::operators::fast_ln_fwd_kernel< \
T, \
U, \
ScaleT, \
VecSize, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>( \
batch_size, \
feature_size, \
epsilon, \
x_data, \
static_cast<const ScaleT *>(void_scale_data), \
static_cast<const ScaleT *>(void_bias_data), \
mean_data, \
var_data, \
y_data); \
} break

#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD(ScaleT) \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 768); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1024); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1280); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1536); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1792); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 2048); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 4096)

#ifdef PADDLE_WITH_CUDA
bool can_call_1024_kernel = false;
if (feature_size == 1024 && scale != nullptr && bias != nullptr) {
can_call_1024_kernel = true;
bool can_call_fast_kernel = false;
if (((feature_size >= 768 && feature_size <= 2048 &&
feature_size % 256 == 0) ||
feature_size == 4096) &&
scale != nullptr && bias != nullptr) {
// can_call_fast_kernel = true;
can_call_fast_kernel = false;
}
if (can_call_1024_kernel) {
const int WARPS_M = 4;
const int WARPS_N = 1;
const int THREADS_PER_WARP = 32;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);

const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
const int ROWS_PER_CTA = WARPS_M;

const int grid = static_cast<int>(
std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));

if (can_call_fast_kernel) {
if (is_scale_bias_same_dtype_with_x) {
paddle::operators::ln_fwd_1024_kernel<
T,
U,
T,
VecSize,
WARPS_M,
WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size,
feature_size,
epsilon,
x_data,
static_cast<const T *>(void_scale_data),
static_cast<const T *>(void_bias_data),
mean_data,
var_data,
y_data);
switch (feature_size) {
PADDLE_LAUNCH_FAST_LAYERNORM_FWD(T);
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Only when feature_size is from 256 to 4096 and is diviaible by "
"256 is supported "
"now"));
break;
}
} else {
paddle::operators::ln_fwd_1024_kernel<
T,
U,
U,
VecSize,
WARPS_M,
WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size,
feature_size,
epsilon,
x_data,
static_cast<const U *>(void_scale_data),
static_cast<const U *>(void_bias_data),
mean_data,
var_data,
y_data);
switch (feature_size) {
PADDLE_LAUNCH_FAST_LAYERNORM_FWD(U);
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Only when feature_size is from 256 to 4096 and is diviaible by "
"is supported "
"now"));
break;
}
}
} else {
#endif
Expand All @@ -197,6 +213,7 @@ void LayerNormKernel(const Context &dev_ctx,
#endif

#undef PADDLE_LAUNCH_LAYERNORM_FWD
#undef PADDLE_LAUNCH_FAST_LAYERNORM_FWD
}

} // namespace phi
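A small host-side sketch of the launch-shape arithmetic that PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE encodes, assuming fp16 inputs (BYTES_PER_LDG = 16, so VecSize = 8); LDGS is the per-thread vectorized load count implied by the kernel's static_assert:

#include <cstdio>

int main() {
  const int kSupported[] = {768, 1024, 1280, 1536, 1792, 2048, 4096};
  const int kVecSize = 8;  // 16-byte loads of 2-byte elements
  const int kThreadsPerWarp = 32;
  printf("%6s %8s %8s %16s %5s\n", "cols", "WARPS_N", "WARPS_M",
         "THREADS_PER_CTA", "LDGS");
  for (int cols : kSupported) {
    const int warps_n = cols < 1024 ? 1 : cols / 1024;
    const int warps_m = 4 / warps_n;
    const int threads_per_row = warps_n * kThreadsPerWarp;
    const int threads_per_cta = warps_m * threads_per_row;
    // The kernel's static_assert requires threads_per_row * ldgs * kVecSize
    // to tile the row exactly.
    const int ldgs = cols / (threads_per_row * kVecSize);
    printf("%6d %8d %8d %16d %5d\n", cols, warps_n, warps_m, threads_per_cta,
           ldgs);
  }
  return 0;
}

Every supported width ends up with 128 threads per CTA; what varies is how many warps cooperate on a row (WARPS_N) and how many rows a CTA covers (WARPS_M = ROWS_PER_CTA), with grid = ceil(batch_size / ROWS_PER_CTA) as in the macro.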