Adding fused_swiglu_probs_bwd op #10604

Open · wants to merge 7 commits into base: dsv3_dev
@@ -0,0 +1,297 @@
// swiglu_probs_grad_op.cu
// Fused backward of prob-weighted SwiGLU: given o1 = [x0; x1], the forward is
// o2 = silu(x0) * x1 and o2_s = o2 * prob; this op produces do1, probs_grad,
// and the recomputed o2_s in a single pass over the activations.
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

#include "paddle/extension.h"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
using BFloat16 = __nv_bfloat16;
#else
// Host-side (and pre-SM80) fallback: bf16 is the upper 16 bits of an fp32,
// so the conversions are bit shifts; float -> bf16 truncates rather than
// rounds, which is acceptable for this storage-only use.
struct BFloat16 {
  uint16_t x;

  __host__ __device__ BFloat16() : x(0) {}

  __host__ __device__ BFloat16(float val) {
    uint32_t* val_bits = reinterpret_cast<uint32_t*>(&val);
    x = static_cast<uint16_t>(*val_bits >> 16);
  }

  __host__ __device__ operator float() const {
    uint32_t bits = static_cast<uint32_t>(x) << 16;
    return *reinterpret_cast<float*>(&bits);
  }
};
#endif
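
// Gradient derivation shared by both kernels. Forward, per element:
//   sig  = sigmoid(x0)        silu = x0 * sig
//   o2   = silu * x1          o2_s = o2 * prob
// With upstream gradient do2_s (w.r.t. o2_s) and do2 = do2_s * prob:
//   d(silu)/d(x0) = sig + x0 * sig * (1 - sig) = sig * (1 + x0 - silu)
//   x0_grad = do2 * x1 * sig * (1 + x0 - silu)
//   x1_grad = do2 * silu
// and, since o2_s is linear in prob, probs_grad = sum_i do2_s[i] * o2[i].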

// One thread block per row; each thread strides across the row, writing both
// halves of do1 plus o2_s, and accumulating its share of probs_grad.
template <int thread_per_block>
__global__ void SwigluProbsGradKernel(
    const BFloat16* o1,           // [seq_len*topk, moe_intermediate_size*2]
    const BFloat16* do2_s,        // [seq_len*topk, moe_intermediate_size]
    const float* unzipped_probs,  // [seq_len*topk, 1]
    BFloat16* do1,                // [seq_len*topk, moe_intermediate_size*2]
    float* probs_grad,            // [seq_len*topk, 1]
    BFloat16* o2_s,               // [seq_len*topk, moe_intermediate_size]
    int moe_intermediate_size) {
  const int row_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const BFloat16* o1_row = o1 + row_idx * moe_intermediate_size * 2;
  const BFloat16* do2_s_row = do2_s + row_idx * moe_intermediate_size;
  BFloat16* do1_row = do1 + row_idx * moe_intermediate_size * 2;
  BFloat16* o2s_row = o2_s + row_idx * moe_intermediate_size;

  float prob = unzipped_probs[row_idx];

  __shared__ float sum_buffer[thread_per_block];

  float local_probs_grad = 0.0f;

  for (int i = tid; i < moe_intermediate_size; i += blockDim.x) {
    float lhs = static_cast<float>(o1_row[i]);
    float rhs = static_cast<float>(o1_row[i + moe_intermediate_size]);

    float sig = 1.0f / (1.0f + expf(-lhs));
    float tmp = sig * lhs;  // silu(lhs)
    float o2_val = tmp * rhs;

    float do2_s_val = static_cast<float>(do2_s_row[i]);
    float do2_val = do2_s_val * prob;

    float x0_grad = do2_val * rhs * sig * (1.0f + lhs - tmp);
    float x1_grad = do2_val * tmp;

    do1_row[i] = BFloat16(x0_grad);
    do1_row[i + moe_intermediate_size] = BFloat16(x1_grad);
    o2s_row[i] = BFloat16(o2_val * prob);

    local_probs_grad += do2_s_val * o2_val;
  }

  // Block-wide tree reduction of the per-thread partial sums
  // (thread_per_block must be a power of two).
  sum_buffer[tid] = local_probs_grad;
  __syncthreads();

  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      sum_buffer[tid] += sum_buffer[tid + stride];
    }
    __syncthreads();
  }

  if (tid == 0) {
    probs_grad[row_idx] = sum_buffer[0];
  }
}
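
// ---------------------------------------------------------------------------
// Hedged sketch: a plain-float CPU reference of the kernel above, intended
// for unit tests (compare against the CUDA outputs with bf16 tolerance).
// The function name and float-only interface are illustrative and not part
// of the op itself.
// ---------------------------------------------------------------------------
inline void SwigluProbsGradReference(const float* o1, const float* do2_s,
                                     const float* probs, float* do1,
                                     float* probs_grad, float* o2_s, int rows,
                                     int inter) {
  for (int r = 0; r < rows; ++r) {
    const float prob = probs[r];
    float acc = 0.0f;
    for (int i = 0; i < inter; ++i) {
      const float lhs = o1[r * inter * 2 + i];
      const float rhs = o1[r * inter * 2 + inter + i];
      const float sig = 1.0f / (1.0f + expf(-lhs));
      const float silu = sig * lhs;
      const float o2 = silu * rhs;
      const float do2 = do2_s[r * inter + i] * prob;
      do1[r * inter * 2 + i] = do2 * rhs * sig * (1.0f + lhs - silu);
      do1[r * inter * 2 + inter + i] = do2 * silu;
      o2_s[r * inter + i] = o2 * prob;
      acc += do2_s[r * inter + i] * o2;
    }
    probs_grad[r] = acc;
  }
}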

typedef struct __align__(8) {
  __nv_bfloat16 x;
  __nv_bfloat16 y;
  __nv_bfloat16 z;
  __nv_bfloat16 w;
} bfloat16x4_t;
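
// The vec4 path below computes in fp32 internally; __expf and __frcp_rn are
// fast-math intrinsics that trade a few ulps of accuracy for throughput,
// a loss that is negligible once values are rounded back to bf16.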

__device__ __forceinline__ float4 fast_swiglu_vec4(const bfloat16x4_t& lhs,
                                                   const bfloat16x4_t& rhs) {
  const float x_f_x = __bfloat162float(lhs.x);
  const float x_f_y = __bfloat162float(lhs.y);
  const float x_f_z = __bfloat162float(lhs.z);
  const float x_f_w = __bfloat162float(lhs.w);

  const float y_f_x = __bfloat162float(rhs.x);
  const float y_f_y = __bfloat162float(rhs.y);
  const float y_f_z = __bfloat162float(rhs.z);
  const float y_f_w = __bfloat162float(rhs.w);

  const float silu_x = x_f_x * __frcp_rn(1.0f + __expf(-x_f_x));
  const float silu_y = x_f_y * __frcp_rn(1.0f + __expf(-x_f_y));
  const float silu_z = x_f_z * __frcp_rn(1.0f + __expf(-x_f_z));
  const float silu_w = x_f_w * __frcp_rn(1.0f + __expf(-x_f_w));

  return {silu_x * y_f_x, silu_y * y_f_y, silu_z * y_f_z, silu_w * y_f_w};
}

__device__ __forceinline__ float4 f4_prod(const float4& x_f,
                                          const float4& y_f) {
  return {x_f.x * y_f.x, x_f.y * y_f.y, x_f.z * y_f.z, x_f.w * y_f.w};
}
__device__ __forceinline__ float4 f4_prod(const float4& x_f, const float& y_f) {
  return {x_f.x * y_f, x_f.y * y_f, x_f.z * y_f, x_f.w * y_f};
}
__device__ __forceinline__ float4 f4_add(const float4& x_f, const float& y_f) {
  return {x_f.x + y_f, x_f.y + y_f, x_f.z + y_f, x_f.w + y_f};
}
__device__ __forceinline__ float4 f4_add(const float4& x_f, const float4& y_f) {
  return {x_f.x + y_f.x, x_f.y + y_f.y, x_f.z + y_f.z, x_f.w + y_f.w};
}
__device__ __forceinline__ float4 f4_sub(const float4& x_f, const float4& y_f) {
  return {x_f.x - y_f.x, x_f.y - y_f.y, x_f.z - y_f.z, x_f.w - y_f.w};
}
__device__ __forceinline__ float4 fast_sig_vec4(const float4& x_vec4) {
  const float sig_x = __frcp_rn(1.0f + __expf(-x_vec4.x));
  const float sig_y = __frcp_rn(1.0f + __expf(-x_vec4.y));
  const float sig_z = __frcp_rn(1.0f + __expf(-x_vec4.z));
  const float sig_w = __frcp_rn(1.0f + __expf(-x_vec4.w));
  return {sig_x, sig_y, sig_z, sig_w};
}
__device__ __forceinline__ float4
load_and_cast_float4(const bfloat16x4_t* x_vec4_ptr) {
  bfloat16x4_t x_vec4 = *x_vec4_ptr;
  return {
      static_cast<float>(x_vec4.x),
      static_cast<float>(x_vec4.y),
      static_cast<float>(x_vec4.z),
      static_cast<float>(x_vec4.w),
  };
}
__device__ __forceinline__ void cast_and_store_bf16x4(bfloat16x4_t* dst_ptr,
                                                      const float4& x_vec4) {
  *dst_ptr = {static_cast<__nv_bfloat16>(x_vec4.x),
              static_cast<__nv_bfloat16>(x_vec4.y),
              static_cast<__nv_bfloat16>(x_vec4.z),
              static_cast<__nv_bfloat16>(x_vec4.w)};
}
// 4-wide dot product: elementwise multiply, then horizontal add.
__device__ __forceinline__ float mreduce_f4(const float4& x_f4,
                                            const float4& y_f4) {
  float x_m = x_f4.x * y_f4.x;
  float y_m = x_f4.y * y_f4.y;
  float z_m = x_f4.z * y_f4.z;
  float w_m = x_f4.w * y_f4.w;
  return x_m + y_m + z_m + w_m;
}
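
// Vectorized variant of SwigluProbsGradKernel: 8-byte bf16x4 loads/stores
// move four elements per thread per iteration; moe_intermediate_size must be
// a multiple of 4 (guaranteed by the host-side dispatcher below).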

template <int thread_per_block>
__global__ void SwigluProbsGradKernelVec4(
    const BFloat16* o1,           // [seq_len*topk, moe_intermediate_size*2]
    const BFloat16* do2_s,        // [seq_len*topk, moe_intermediate_size]
    const float* unzipped_probs,  // [seq_len*topk, 1]
    BFloat16* do1,                // [seq_len*topk, moe_intermediate_size*2]
    float* probs_grad,            // [seq_len*topk, 1]
    BFloat16* o2_s,               // [seq_len*topk, moe_intermediate_size]
    int moe_intermediate_size) {
  constexpr int numel_per_thread = 4;
  constexpr int k_warp_size = 32;
  const int row_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const BFloat16* o1_row = o1 + row_idx * moe_intermediate_size * 2;
  const BFloat16* do2_s_row = do2_s + row_idx * moe_intermediate_size;
  const bfloat16x4_t* o1_row_left_half_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(o1_row);
  const bfloat16x4_t* do2_s_row_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(do2_s_row);
  const bfloat16x4_t* o1_row_right_half_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(o1_row + moe_intermediate_size);
  BFloat16* do1_row = do1 + row_idx * moe_intermediate_size * 2;
  BFloat16* o2s_row = o2_s + row_idx * moe_intermediate_size;
  bfloat16x4_t* do1_row_vec4 = reinterpret_cast<bfloat16x4_t*>(do1_row);
  bfloat16x4_t* o2s_row_vec4 = reinterpret_cast<bfloat16x4_t*>(o2s_row);

  float prob = unzipped_probs[row_idx];
  __shared__ float sum_buffer[thread_per_block];

  float local_probs_grad = 0.0f;

  const int vec_numel = moe_intermediate_size / numel_per_thread;
  for (int i = tid; i < vec_numel; i += blockDim.x) {
    float4 lhs_vec4 = load_and_cast_float4(o1_row_left_half_vec4 + i);
    float4 rhs_vec4 = load_and_cast_float4(o1_row_right_half_vec4 + i);
    float4 do2_s_val_vec4 = load_and_cast_float4(do2_s_row_vec4 + i);
    float4 sig_vec4 = fast_sig_vec4(lhs_vec4);
    float4 tmp_vec4 = f4_prod(sig_vec4, lhs_vec4);  // silu(lhs)
    float4 o2_val_vec4 = f4_prod(tmp_vec4, rhs_vec4);
    float4 o2s_val_vec4 = f4_prod(o2_val_vec4, prob);
    float4 do2_val_vec4 = f4_prod(do2_s_val_vec4, prob);
    // x0_grad = do2 * rhs * sig * (1 + lhs - silu), as in the scalar kernel.
    float4 x0_grad_vec4 = f4_prod(
        do2_val_vec4,
        f4_prod(rhs_vec4,
                f4_prod(sig_vec4, (f4_sub(f4_add(lhs_vec4, 1.0f), tmp_vec4)))));
    float4 x1_grad_vec4 = f4_prod(do2_val_vec4, tmp_vec4);
    cast_and_store_bf16x4(do1_row_vec4 + i, x0_grad_vec4);
    cast_and_store_bf16x4(do1_row_vec4 + i + vec_numel, x1_grad_vec4);
    cast_and_store_bf16x4(o2s_row_vec4 + i, o2s_val_vec4);
    local_probs_grad += mreduce_f4(do2_s_val_vec4, o2_val_vec4);
  }

  sum_buffer[tid] = local_probs_grad;
  __syncthreads();
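  // Two-stage reduction: shared-memory tree until one warp of partials
  // remains, then warp shuffles finish the sum without further barriers.
  // The compile-time thread_per_block (== blockDim.x at launch) gives the
  // first loop a constant trip count so #pragma unroll can take effect.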

#pragma unroll
  for (int stride = thread_per_block / 2; stride >= k_warp_size;
       stride >>= 1) {
    if (tid < stride) {
      sum_buffer[tid] += sum_buffer[tid + stride];
    }
    __syncthreads();
  }

  if (tid < k_warp_size) {
    local_probs_grad = sum_buffer[tid];
#pragma unroll
    for (int offset = k_warp_size / 2; offset > 0; offset >>= 1) {
      local_probs_grad +=
          __shfl_down_sync(0xFFFFFFFF, local_probs_grad, offset);
    }
  }

  if (tid == 0) {
    probs_grad[row_idx] = local_probs_grad;
  }
}

std::vector<paddle::Tensor> SwigluProbsGradCUDABackward(
    const paddle::Tensor& o1,
    const paddle::Tensor& do2_s,
    const paddle::Tensor& unzipped_probs) {
  // o1 is viewed as [d0, d1, 2*moe_intermediate_size]; the kernels only need
  // the total row count d0 * d1.
  auto o1_dims = o1.dims();
  const int topk = o1_dims[0];
  const int seqlen = o1_dims[1];
  const int seq_len_topk = topk * seqlen;
  const int moe_intermediate_size_2 = o1_dims[2];
  const int moe_intermediate_size = moe_intermediate_size_2 / 2;

  auto do1 = paddle::empty_like(o1);
  auto probs_grad = paddle::empty(
      {o1_dims[0], o1_dims[1]}, paddle::DataType::FLOAT32, o1.place());
  auto o2_s = paddle::empty_like(do2_s);

  const BFloat16* o1_ptr =
      reinterpret_cast<const BFloat16*>(o1.data<phi::bfloat16>());
  const BFloat16* do2_s_ptr =
      reinterpret_cast<const BFloat16*>(do2_s.data<phi::bfloat16>());
  const float* unzipped_probs_ptr = unzipped_probs.data<float>();
  BFloat16* do1_ptr = reinterpret_cast<BFloat16*>(do1.data<phi::bfloat16>());
  float* probs_grad_ptr = probs_grad.data<float>();
  BFloat16* o2_s_ptr = reinterpret_cast<BFloat16*>(o2_s.data<phi::bfloat16>());

  // One block per row; take the scalar path when rows are not vec4-aligned.
  constexpr int block_size = 256;
  if (moe_intermediate_size % 4 != 0) {
    SwigluProbsGradKernel<block_size>
        <<<seq_len_topk, block_size, 0, o1.stream()>>>(o1_ptr,
                                                       do2_s_ptr,
                                                       unzipped_probs_ptr,
                                                       do1_ptr,
                                                       probs_grad_ptr,
                                                       o2_s_ptr,
                                                       moe_intermediate_size);
  } else {
    SwigluProbsGradKernelVec4<block_size>
        <<<seq_len_topk, block_size, 0, o1.stream()>>>(o1_ptr,
                                                       do2_s_ptr,
                                                       unzipped_probs_ptr,
                                                       do1_ptr,
                                                       probs_grad_ptr,
                                                       o2_s_ptr,
                                                       moe_intermediate_size);
  }

  return {do1, probs_grad, o2_s};
}

PD_BUILD_OP(fused_swiglu_probs_bwd)
    .Inputs({"o1", "do2_s", "unzipped_probs"})
    .Outputs({"do1", "probs_grad", "o2_s"})
    .SetKernelFn(PD_KERNEL(SwigluProbsGradCUDABackward));
1 change: 1 addition & 0 deletions slm/model_zoo/gpt-3/external_ops/setup_fp8.py
@@ -40,6 +40,7 @@ def setup_fused_quant_ops():
"fused_quanted_ops/fused_act_quant.cu",
"fused_quanted_ops/fused_act_dequant.cu",
"fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
"fused_quanted_ops/fused_swiglu_probs_bwd.cu",
"fused_quanted_ops/fused_spaq.cu",
],
extra_compile_args={
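
For reference, a minimal usage sketch of the new op from Python. It assumes the extension built by setup_fp8.py is importable as FusedQuantOps and that custom ops are exposed as module attributes; the module name, shapes, and sizes below are illustrative, not taken from this diff.

import paddle
import FusedQuantOps as fqo  # assumed module name for the built extension

rows, inter = 8, 256  # rows = seq_len * topk, inter = moe_intermediate_size
o1 = paddle.randn([1, rows, inter * 2]).astype("bfloat16")
do2_s = paddle.randn([1, rows, inter]).astype("bfloat16")
probs = paddle.rand([1, rows], dtype="float32")

# Returns the activation grad, the routing-prob grad, and the recomputed
# prob-weighted SwiGLU output.
do1, probs_grad, o2_s = fqo.fused_swiglu_probs_bwd(o1, do2_s, probs)
assert do1.shape == o1.shape and o2_s.shape == do2_s.shape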