include/flashinfer/norm.cuh (16 additions, 155 deletions)
@@ -29,7 +29,7 @@ namespace norm {

 template <uint32_t VEC_SIZE, typename T>
 __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
-                              const uint32_t d, float eps) {
+                              const uint32_t d, float weight_bias, float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
@@ -87,7 +87,7 @@ __global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T*
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
+      output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -105,7 +105,8 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &weight, &output, &d, &eps};
+  float weight_bias = 0.f;
+  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};

   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = RMSNormKernel<VEC_SIZE, T>;
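
Note: after this change a single kernel covers both weighting conventions. A minimal PyTorch sketch of the unified semantics, with eps placed as in the tests/test_norm.py references (the name `unified_rms_norm` is illustrative, not an API in this PR):

```python
import torch

def unified_rms_norm(x, w, weight_bias=0.0, eps=1e-6):
    # weight_bias = 0.0 -> llama-style RMSNorm: x_hat * w
    # weight_bias = 1.0 -> gemma-style RMSNorm: x_hat * (1 + w)
    x = x.float()
    x_hat = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x_hat * (weight_bias + w.float())
```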
@@ -116,7 +117,8 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_

 template <uint32_t VEC_SIZE, typename T>
 __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                      T* __restrict__ weight, const uint32_t d, float eps) {
+                                      T* __restrict__ weight, const uint32_t d, float weight_bias,
+                                      float eps) {
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   constexpr uint32_t warp_size = 32;
@@ -187,7 +189,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
+      input_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j]));
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -205,7 +207,8 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = (num_warps + d) * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &d, &eps};
+  float weight_bias = 0.f;
+  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};

   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
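
Note: the fused variant keeps its in-place contract: `residual` is overwritten with `input + residual`, and `input` is overwritten with the normalized, scaled result. A hedged PyTorch sketch of that contract with the new `weight_bias` parameter (names illustrative; eps placement follows the test references):

```python
import torch

def unified_fused_add_rms_norm(x, residual, w, weight_bias=0.0, eps=1e-6):
    # Mirrors the in-place behavior of FusedAddRMSNormKernel:
    #   residual <- x + residual
    #   x        <- RMSNorm(x + residual) * (weight_bias + w)
    z = x.float() + residual.float()
    residual.copy_(z.to(residual.dtype))
    z_hat = z * torch.rsqrt(z.pow(2).mean(dim=-1, keepdim=True) + eps)
    x.copy_((z_hat * (weight_bias + w.float())).to(x.dtype))
    return x, residual
```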
@@ -215,73 +218,6 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   return cudaSuccess;
 }

-template <uint32_t VEC_SIZE, typename T>
-__global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
-                                   T* __restrict__ output, const uint32_t d, float eps) {
-  const uint32_t bx = blockIdx.x;
-  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  constexpr uint32_t warp_size = 32;
-  const uint32_t num_warps = blockDim.y;
-  const uint32_t thread_id = tx + ty * warp_size;
-  const uint32_t num_threads = num_warps * warp_size;
-  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-  extern __shared__ float smem[];
-
-  float sum_sq = 0.f;
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      sum_sq += float(input_vec[j]) * float(input_vec[j]);
-    }
-  }
-
-  // first, warp reduce sum
-#pragma unroll
-  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-    sum_sq += math::shfl_xor_sync(sum_sq, offset);
-  }
-
-  smem[ty] = sum_sq;
-  __syncthreads();
-  // then, cross warp reduce sum using only the first warp
-  if (ty == 0) {
-    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
-#pragma unroll
-    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-      sum_sq += math::shfl_xor_sync(sum_sq, offset);
-    }
-    smem[0] = sum_sq;
-  }
-  __syncthreads();
-
-  float rms_rcp = math::rsqrt(smem[0] / (float(d) + eps));
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> output_vec;
-    input_vec.fill(0.f);
-    weight_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-}
-
 template <typename T>
 cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                          float eps = 1e-5, cudaStream_t stream = 0) {
@@ -292,92 +228,16 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &weight, &output, &d, &eps};
+  float weight_bias = 1.f;
+  void* args[] = {&input, &weight, &output, &d, &weight_bias, &eps};

   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
+    auto kernel = RMSNormKernel<VEC_SIZE, T>;
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
   return cudaSuccess;
 }

-template <uint32_t VEC_SIZE, typename T>
-__global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
-                                           T* __restrict__ weight, const uint32_t d, float eps) {
-  const uint32_t bx = blockIdx.x;
-  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  constexpr uint32_t warp_size = 32;
-  const uint32_t num_warps = blockDim.y;
-  const uint32_t thread_id = tx + ty * warp_size;
-  const uint32_t num_threads = num_warps * warp_size;
-  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-  extern __shared__ float smem[];
-
-  float sum_sq = 0.f;
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0.f);
-    vec_t<T, VEC_SIZE> residual_vec;
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = float(input_vec[j]);
-      x += float(residual_vec[j]);
-      sum_sq += x * x;
-      residual_vec[j] = (T)x;
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-
-  // first, warp reduce sum
-#pragma unroll
-  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-    sum_sq += math::shfl_xor_sync(sum_sq, offset);
-  }
-
-  smem[ty] = sum_sq;
-  __syncthreads();
-  // then, cross warp reduce sum using only the first warp
-  if (ty == 0) {
-    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
-#pragma unroll
-    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-      sum_sq += math::shfl_xor_sync(sum_sq, offset);
-    }
-    smem[0] = sum_sq;
-  }
-  __syncthreads();
-
-  float rms_rcp = math::rsqrt(smem[0] / (float(d) + eps));
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> residual_vec;
-    input_vec.fill(0.f);
-    weight_vec.fill(0.f);
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-}
-
 template <typename T>
 cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
                                  float eps = 1e-5, cudaStream_t stream = 0) {
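
Note: the Gemma-specific kernels can be deleted because gemma-style scaling is the same computation with the bias folded in: normalizing and multiplying by (1 + w) equals the unified kernel with weight_bias = 1. A small self-contained PyTorch check of that identity (illustrative only, not part of this PR's test suite):

```python
import torch

def unified(x, w, weight_bias, eps=1e-6):
    # Reference semantics of the shared RMSNormKernel after this PR.
    x_hat = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x_hat * (weight_bias + w)

x, w = torch.randn(2, 128), torch.randn(128)
gemma_style = unified(x, w, weight_bias=1.0)        # (1 + w) scaling
folded = unified(x, 1.0 + w, weight_bias=0.0)       # bias folded into w
torch.testing.assert_close(gemma_style, folded)
```

The hunk below also grows GemmaFusedAddRMSNorm's shared-memory reservation from num_warps floats to (num_warps + d) floats, matching what the shared FusedAddRMSNormKernel launcher above already reserves.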
@@ -387,11 +247,12 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &d, &eps};
+  const uint32_t smem_size = (num_warps + d) * sizeof(float);
+  float weight_bias = 1.f;
+  void* args[] = {&input, &residual, &weight, &d, &weight_bias, &eps};

   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
+    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
tests/test_norm.py (10 additions, 8 deletions)
@@ -21,19 +21,21 @@


 def llama_rms_norm(x, w, eps=1e-6):
-    def _norm(x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
-
-    output = _norm(x.float()).type_as(x)
-    return output * w
+    orig_dtype = x.dtype
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * w.float()
+    x = x.to(orig_dtype)
+    return x


 def gemma_rms_norm(x, w, eps=1e-6):
     orig_dtype = x.dtype
     x = x.float()
     variance = x.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + eps)
-    x = x * (1.0 + w)
+    x = x * (1.0 + w.float())
     x = x.to(orig_dtype)
     return x

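The new `w.float()` casts are not cosmetic: for bfloat16 weights, computing `1.0 + w` in bfloat16 rounds away low-order bits of `w` before the multiply, whereas upcasting first keeps them. A small demonstration of the effect (assumes PyTorch with bfloat16 support):

```python
import torch

w = torch.full((4,), 0.01, dtype=torch.bfloat16)
in_bf16 = (1.0 + w).float()  # add performed in bfloat16, then upcast
in_fp32 = 1.0 + w.float()    # upcast first, add performed in float32
print((in_bf16 - in_fp32).abs().max())  # nonzero: the bf16 add lost precision
```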
@@ -45,7 +47,7 @@ def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
     x = x.float()
     variance = x.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + eps)
-    x = x * (1.0 + w)
+    x = x * (1.0 + w.float())
     x = x.to(orig_dtype)
     return x, residual

@@ -58,7 +60,7 @@ def fused_add_rms_norm(x, residual, weight, eps):

     variance = x.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + eps)
-    x = x.to(orig_dtype) * weight
+    x = (x * weight.float()).to(orig_dtype)
     return x, residual
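
Likewise, the `fused_add_rms_norm` reference now multiplies by the weight in float32 and downcasts once at the end, instead of downcasting `x` first and multiplying in the storage dtype; this matches the kernels, which accumulate in float and convert to T only on the final store. A quick sketch of the numerical difference (illustrative):

```python
import torch

x = torch.randn(2, 64)                        # normalized activations, float32
weight = torch.randn(64, dtype=torch.float16)

old = x.to(torch.float16) * weight            # round x to fp16, multiply in fp16
new = (x * weight.float()).to(torch.float16)  # multiply in fp32, round once
print((old.float() - new.float()).abs().max())  # typically nonzero
```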

