
Commit 09e9245

Add custom kernel for RMS normalization (#16)
1 parent c45f3c3 commit 09e9245

File tree

9 files changed: +243 −58 lines

cacheflow/models/layernorm.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+from cacheflow import layernorm_ops
+
+
+class RMSNorm(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty_like(x)
+        layernorm_ops.rms_norm(
+            out,
+            x,
+            self.weight.data,
+            self.variance_epsilon,
+        )
+        return out
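
As a point of reference, a minimal usage sketch for the new module — illustrative only, not part of the diff, and assuming the cacheflow.layernorm_ops extension has been built:

import torch
from cacheflow.models.layernorm import RMSNorm

norm = RMSNorm(hidden_size=4096).half().cuda()
x = torch.randn(8, 4096, dtype=torch.half, device='cuda')  # [num_tokens, hidden_size]
y = norm(x)  # same shape as x; each token is RMS-normalized and scaled by the learned weight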

cacheflow/models/llama.py

Lines changed: 4 additions & 19 deletions
@@ -12,6 +12,7 @@
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -23,22 +24,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
-
-
 class LlamaMLP(nn.Module):
 
     def __init__(
@@ -148,8 +133,8 @@ def __init__(self, config: LlamaConfig):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -190,7 +175,7 @@ def __init__(self, config: LlamaConfig):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                    perform_initialization=False)
         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,

csrc/attention_kernels.cu

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 
 #include "attention_utils.h"
 #include "cuda_primitives.h"
+#include "reduction_utils.h"
 
 #include <algorithm>
 
csrc/attention_utils.h

Lines changed: 0 additions & 39 deletions
@@ -159,45 +159,6 @@ struct Qk_dot<uint16_t, 4> {
   }
 };
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
-inline __device__ float block_sum(float* red_smem, float sum)
-{
-
-  // Decompose the thread index into warp / lane.
-  int warp = threadIdx.x / WARP_SIZE;
-  int lane = threadIdx.x % WARP_SIZE;
-
-  // Compute the sum per warp.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
-  }
-
-  // Warp leaders store the data to shared memory.
-  if (lane == 0) {
-    red_smem[warp] = sum;
-  }
-
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-  // The warps compute the final sums.
-  if (lane < WARPS_PER_BLOCK) {
-    sum = red_smem[lane];
-  }
-
-  // Parallel reduction inside the warp.
-#pragma unroll
-  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
-  }
-
-  // Broadcast to other threads.
-  return __shfl_sync(uint32_t(-1), sum, 0);
-}
-
 } // namespace cacheflow
 
 #undef MMHA_USE_FP32_ACUM_FOR_FMA

csrc/layernorm.cpp

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#include <torch/extension.h>
+
+void rms_norm(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& weight,
+  float epsilon);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "rms_norm",
+    &rms_norm,
+    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+}
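
The binding writes into a caller-provided output tensor rather than returning a new one. A short illustrative call from Python — a sketch under the assumption that the extension is importable as cacheflow.layernorm_ops, as the Python wrapper above expects:

import torch
from cacheflow import layernorm_ops

x = torch.randn(16, 1024, dtype=torch.half, device='cuda')   # [num_tokens, hidden_size]
weight = torch.ones(1024, dtype=torch.half, device='cuda')   # [hidden_size]
out = torch.empty_like(x)                                    # filled in place by the op
layernorm_ops.rms_norm(out, x, weight, 1e-6)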

csrc/layernorm_kernels.cu

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "reduction_utils.h"
+
+namespace cacheflow {
+
+// TODO(woosuk): Further optimize this kernel.
+template<typename scalar_t>
+__global__ void rms_norm_kernel(
+  scalar_t* __restrict__ out,           // [num_tokens, hidden_size]
+  const scalar_t* __restrict__ input,   // [num_tokens, hidden_size]
+  const scalar_t* __restrict__ weight,  // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float) input[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) input[blockIdx.x * hidden_size + idx];
+    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+  }
+}
+
+} // namespace cacheflow
+
+void rms_norm(
+  torch::Tensor& out,     // [num_tokens, hidden_size]
+  torch::Tensor& input,   // [num_tokens, hidden_size]
+  torch::Tensor& weight,  // [hidden_size]
+  float epsilon) {
+  int num_tokens = input.size(0);
+  int hidden_size = input.size(1);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    input.scalar_type(),
+    "rms_norm_kernel",
+    [&] {
+      cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(),
+        epsilon,
+        num_tokens,
+        hidden_size);
+    });
+}
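
In this kernel each CUDA block handles one token (one row of the input) and its threads stride over the hidden dimension; the sum of squares is combined with blockReduceSum and the normalization is computed in float32 before casting back. A pure-PyTorch sketch of the per-token math, for illustration only:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: [num_tokens, hidden_size]; the kernel assigns one block per row of x.
    x_f32 = x.float()
    # Mean of squares over the hidden dimension (what blockReduceSum accumulates per block).
    inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Scale in float32, cast back to the input dtype, then apply the elementwise weight.
    return (x_f32 * inv_rms).to(x.dtype) * weight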

csrc/reduction_utils.h

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+#pragma once
+
+namespace cacheflow {
+
+template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
+inline __device__ float block_sum(float* red_smem, float sum)
+{
+
+  // Decompose the thread index into warp / lane.
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // Compute the sum per warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+  }
+
+  // Warp leaders store the data to shared memory.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The warps compute the final sums.
+  if (lane < WARPS_PER_BLOCK) {
+    sum = red_smem[lane];
+  }
+
+  // Parallel reduction inside the warp.
+#pragma unroll
+  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
+    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return __shfl_sync(uint32_t(-1), sum, 0);
+}
+
+#define FINAL_MASK 0xffffffff
+
+template<typename T>
+__inline__ __device__ T warpReduceSum(T val)
+{
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1)
+    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
+  return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template<typename T>
+__inline__ __device__ T blockReduceSum(T val)
+{
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  val = warpReduceSum<T>(val);
+
+  if (lane == 0)
+    shared[wid] = val;
+
+  __syncthreads();
+
+  // Use blockDim.x / 32.f (rather than a shift) so that block sizes
+  // that are not a multiple of 32 are handled correctly.
+  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val);
+
+  return val;
+}
+
+} // namespace cacheflow
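
blockReduceSum works in two levels: a butterfly (XOR-shuffle) reduction inside each 32-lane warp, warp leaders spilling their partial sums to shared memory, and a second warp-level reduction over those partials. A small Python model of that data movement — purely illustrative, and assuming the block size is a multiple of 32:

def warp_reduce_sum(vals):
    # Butterfly (XOR-shuffle) reduction: after log2(n) rounds every lane holds the total.
    vals = list(vals)
    mask = len(vals) // 2
    while mask > 0:
        vals = [vals[i] + vals[i ^ mask] for i in range(len(vals))]
        mask //= 2
    return vals

def block_reduce_sum(block_vals, warp_size=32):
    # Reduce within each warp, then reduce the per-warp totals in a final warp-level pass.
    warps = [block_vals[i:i + warp_size] for i in range(0, len(block_vals), warp_size)]
    warp_totals = [warp_reduce_sum(w)[0] for w in warps]            # lane 0 of each warp -> shared memory
    padded = warp_totals + [0.0] * (warp_size - len(warp_totals))   # unused slots contribute 0
    return warp_reduce_sum(padded)[0]

assert block_reduce_sum([1.0] * 96) == 96.0  # e.g. a block of 3 warps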

setup.py

Lines changed: 8 additions & 0 deletions
@@ -31,6 +31,14 @@
 )
 ext_modules.append(positional_encoding_extension)
 
+# Layer normalization kernels.
+layernorm_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.layernorm_ops',
+    sources=['csrc/layernorm.cpp', 'csrc/layernorm_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(layernorm_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,

tests/kernels/layernorm.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+from cacheflow import layernorm_ops
+
+
+class RefRMSNorm(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
+        self.weight = nn.Parameter(weight)
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+        return self.weight * hidden_states
+
+
+@torch.inference_mode()
+def test_rms_norm(
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
+    ref = RefRMSNorm(hidden_size).to(dtype).cuda()
+
+    out = torch.empty_like(x)
+    layernorm_ops.rms_norm(
+        out,
+        x,
+        ref.weight.data,
+        ref.variance_epsilon,
+    )
+    ref_out = ref(x)
+    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
+
+
+if __name__ == '__main__':
+    for dtype in [torch.half, torch.float]:
+        for num_tokens in [7, 128, 2048]:
+            for hidden_size in [13, 64, 1024, 5120]:
+                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
+                      f'{num_tokens}, hidden_size={hidden_size}')
+                test_rms_norm(
+                    num_tokens=num_tokens,
+                    hidden_size=hidden_size,
+                    dtype=dtype,
+                )
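
The test is driven by the __main__ loop above rather than a test runner; if one wanted pytest to collect the same sweep, a hypothetical parametrized wrapper (not part of this commit, and assuming tests/kernels is importable as a module path) could look like:

import pytest
import torch

from tests.kernels.layernorm import test_rms_norm as run_rms_norm_check  # hypothetical import path

@pytest.mark.parametrize('dtype', [torch.half, torch.float])
@pytest.mark.parametrize('num_tokens', [7, 128, 2048])
@pytest.mark.parametrize('hidden_size', [13, 64, 1024, 5120])
def test_rms_norm_matches_reference(num_tokens, hidden_size, dtype):
    run_rms_norm_check(num_tokens=num_tokens, hidden_size=hidden_size, dtype=dtype)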

0 commit comments
