
Commit 4fed528

feat: add RMSNorm CUDA kernel, unit test, and benchmark script
1 parent 0aa27f1 commit 4fed528

7 files changed: +238 −40 lines

colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 25 additions & 3 deletions
@@ -9,6 +9,7 @@
     LlamaForCausalLM,
     LlamaMLP,
     LlamaModel,
+    LlamaRMSNorm,
 )

 from colossalai.inference.batch_bucket import BatchBucket
@@ -19,6 +20,7 @@
     decoding_fused_rotary_embedding,
     flash_decoding_attention,
     get_xine_cache,
+    rms_layernorm,
     rotary_embedding,
 )
 from colossalai.logging import get_dist_logger
@@ -124,7 +126,7 @@ def llama_model_forward(
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
         residual = residual[last_token_indexs - 1].contiguous()
     norm_output = torch.empty_like(hidden_states)
-    hidden_states, _ = self.norm(hidden_states, norm_output, residual)
+    hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)

     return hidden_states

@@ -167,7 +169,7 @@ def llama_decoder_layer_forward(
         use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
     """

-    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -185,12 +187,32 @@
     )

     # Fully Connected
-    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     hidden_states = self.mlp(hidden_states)

     return hidden_states, residual


+def llama_rmsnorm_forward(
+    self: LlamaRMSNorm,
+    hidden_states: torch.Tensor,
+    norm_output: torch.Tensor,
+    residual: torch.Tensor = None,
+    use_cuda_kernel: bool = True,
+):
+    if use_cuda_kernel:
+        if residual is not None:
+            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
+            return hidden_states, residual
+
+        if norm_output is None:
+            norm_output = torch.empty_like(hidden_states)
+        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
+        return norm_output, hidden_states
+    else:
+        return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
+
+
 class NopadLlamaAttention(LlamaAttention):
     def __init__(
         self,
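
The new `llama_rmsnorm_forward` keeps the Triton `rms_layernorm` as the fallback path and otherwise dispatches to the CUDA ops: with a residual it calls the fused in-place kernel and returns the mutated `(hidden_states, residual)` pair; without one it writes into `norm_output` and returns `(norm_output, hidden_states)` so the caller can reuse the input as the next residual. As a mental model (and a baseline the kernels can be checked against), a plain-PyTorch reference would look roughly like the hypothetical sketch below, which is not part of this commit:

import torch


def ref_rms_layernorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Compute the RMS statistic in fp32, normalize, then scale by the learned weight.
    x_fp32 = x.float()
    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    return (x_fp32 * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def ref_fused_add_rms_layernorm(x, residual, weight, eps):
    # Mirrors the fused kernel's contract: the new residual is x + residual,
    # and the normalized output is RMSNorm(x + residual) * weight.
    new_residual = x + residual
    return ref_rms_layernorm(new_residual, weight, eps), new_residual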

colossalai/inference/modeling/policy/nopadding_llama.py

Lines changed: 4 additions & 31 deletions
@@ -1,6 +1,5 @@
 from functools import partial

-import torch
 from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm

@@ -10,34 +9,14 @@
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,
+    llama_rmsnorm_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

-try:
-    from colossalai.kernel.triton import rms_layernorm
-
-    HAS_TRITON_RMSNORM = True
-except:
-    print("you should install triton from https://github.com/openai/triton")
-    HAS_TRITON_RMSNORM = False
-
-
-def get_triton_rmsnorm_forward():
-    if HAS_TRITON_RMSNORM:
-
-        def _triton_rmsnorm_forward(
-            self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
-        ):
-            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
-
-        return _triton_rmsnorm_forward
-    else:
-        return None
-

 class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def __init__(self) -> None:
@@ -84,15 +63,9 @@ def module_policy(self):
             description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
         )

-        infer_forward = None
-        if HAS_TRITON_RMSNORM:
-            infer_forward = get_triton_rmsnorm_forward()
-
-        if infer_forward is not None:
-            method_replacement = {"forward": partial(infer_forward)}
-            self.append_or_create_method_replacement(
-                description=method_replacement, policy=policy, target_key=LlamaRMSNorm
-            )
+        infer_forward = llama_rmsnorm_forward
+        method_replacement = {"forward": partial(infer_forward)}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)

         return policy
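
Since `llama_rmsnorm_forward` now chooses between the CUDA and Triton paths internally via `use_cuda_kernel`, the policy no longer needs the `HAS_TRITON_RMSNORM` availability check and always installs the replacement on `LlamaRMSNorm`. The shardformer machinery does the actual binding; a standalone sketch of the resulting behaviour (hypothetical shapes, CUDA device assumed, not the shardformer code path itself) might look like:

import types

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from colossalai.inference.modeling.models.nopadding_llama import llama_rmsnorm_forward

norm = LlamaRMSNorm(4096).cuda().half()
# Bind the replacement forward to this module instance, as the policy does for every LlamaRMSNorm.
norm.forward = types.MethodType(llama_rmsnorm_forward, norm)

hidden_states = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
norm_output = torch.empty_like(hidden_states)
out, next_residual = norm.forward(hidden_states, norm_output, residual=None, use_cuda_kernel=True)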

examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py renamed to examples/inference/benchmark_ops/benchmark_rmsnorm.py

Lines changed: 14 additions & 5 deletions
@@ -1,14 +1,14 @@
 import torch
-import triton

+from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import rms_layernorm

 try:
     import triton  # noqa
-
 except ImportError:
     print("please install triton from https://github.com/openai/triton")

+inference_ops = InferenceOpsLoader().load()

 # Triton benchmark plot attributions
 configs = [
@@ -19,16 +19,20 @@
         line_vals=[
             "vllm_rms_layernorm",
             "triton_rms_layernorm",
-            "triton_rms_layernorm_with_residual",
+            "cuda_rms_layernorm",
             "vllm_rms_layernorm_with_residual",
+            "triton_rms_layernorm_with_residual",
+            "cuda_rms_layernorm_with_residual",
         ],
         line_names=[
             "vllm_rms_layernorm",
             "triton_rms_layernorm",
-            "triton_rms_layernorm_with_residual",
+            "cuda_rms_layernorm",
             "vllm_rms_layernorm_with_residual",
+            "triton_rms_layernorm_with_residual",
+            "cuda_rms_layernorm_with_residual",
         ],
-        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
+        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
         ylabel="ms",
         plot_name=f"RMSNorm benchmarking results",
         args={"HIDDEN_SIZE": 1024},
@@ -62,10 +66,15 @@ def benchmark_rms_layernorm(
         fn = lambda: vllm_norm(x)
     elif provider == "triton_rms_layernorm":
         fn = lambda: rms_layernorm(x, weight, eps=eps)
+    elif provider == "cuda_rms_layernorm":
+        out = torch.empty_like(x)
+        fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps)
     elif provider == "vllm_rms_layernorm_with_residual":
        fn = lambda: vllm_norm(x, residual=residual)
     elif provider == "triton_rms_layernorm_with_residual":
         fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
+    elif provider == "cuda_rms_layernorm_with_residual":
+        fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
     else:
         raise ValueError("Undefined provider.")
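
Note the differing calling conventions across providers: the Triton `rms_layernorm` allocates and returns its output, the CUDA `rms_layernorm` writes into a preallocated `out`, and both `*_with_residual` variants update `x`/`residual` in place. Outside the `triton.testing` plotting harness, a rough standalone comparison of the two non-residual paths could be done with CUDA events (a sketch with assumed sizes, not part of the commit):

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import rms_layernorm

inference_ops = InferenceOpsLoader().load()


def time_ms(fn, iters=100):
    # Simple CUDA-event timing; warm up first so one-off compilation is excluded.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
weight = torch.ones(1024, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
eps = 1e-5

print("triton rms_layernorm:", time_ms(lambda: rms_layernorm(x, weight, eps=eps)), "ms")
print("cuda rms_layernorm:  ", time_ms(lambda: inference_ops.rms_layernorm(out, x, weight, eps)), "ms")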

extensions/csrc/cuda/colossal_inference_C_frontend.cpp

Lines changed: 16 additions & 0 deletions
@@ -9,7 +9,23 @@ void decode_kv_cache_memcpy(
     torch::Tensor& sequence_lengths,  // [batch_size]
     torch::Tensor& block_tables);     // [batch_size, max_seq_len]

+void rms_layernorm(torch::Tensor& out,     // [..., hidden_size]
+                   torch::Tensor& input,   // [..., hidden_size]
+                   torch::Tensor& weight,  // [hidden_size]
+                   float epsilon);
+
+void fused_add_rms_layernorm(torch::Tensor& input,     // [..., hidden_size]
+                             torch::Tensor& residual,  // [..., hidden_size]
+                             torch::Tensor& weight,    // [hidden_size]
+                             float epsilon);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
         "Copy the GPU memory of kvcache during the decode stage.");
+
+  m.def("rms_layernorm", &rms_layernorm,
+        "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+
+  m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
+        "In-place fused Add and RMS Normalization.");
 }
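
As the docstring of the second binding says, `fused_add_rms_layernorm` is in place: it takes no output tensor, overwrites `residual` with `input + residual`, and overwrites `input` with the normalized sum scaled by `weight` (see the kernel below). A hypothetical check of that contract from Python, assuming the extension has been built and loaded via the inference ops loader:

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

expected_residual = x + residual  # what `residual` should hold after the call
inference_ops.fused_add_rms_layernorm(x, residual, weight, 1e-6)

assert torch.allclose(residual, expected_residual, atol=1e-3, rtol=1e-3)
# `x` now holds RMSNorm(x_old + residual_old) * weight; the op itself returns nothing.
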
extensions/csrc/cuda/rms_layernorm_kernel.cu

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+/*This code from VLLM:
+ *     https://github.com/vllm-project/vllm/
+ *     with minor changes. */
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <stdio.h>
+
+
+#include "block_reduce.h"
+#include "type_shim.h"
+
+template<typename scalar_t>
+__global__ void rms_layernorm_kernel(
+  scalar_t* __restrict__ out,           // [..., hidden_size]
+  const scalar_t* __restrict__ input,   // [..., hidden_size]
+  const scalar_t* __restrict__ weight,  // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  /*
+   * since the open-sourced LLM's hidden dimensions mainly range from
+   * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
+   * hidden dimension limit to 8192, and each thread's capacity
+   * for caching input tensors to 8 (8192 = 8 * 1024) which
+   * will cause problems for extremely large models, such as
+   * Megatron-Turing NLG 530B with hidden dimensions up to 20480
+   */
+  float x_local[8];
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
+    variance += x_local[cnt] * x_local[cnt];
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
+  }
+}
+
+template<typename scalar_t>
+__global__ void fused_add_rms_layernorm_kernel(
+  scalar_t* __restrict__ input,         // [..., hidden_size]
+  scalar_t* __restrict__ residual,      // [..., hidden_size]
+  const scalar_t* __restrict__ weight,  // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+  float x_local[8];
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
+    x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx];
+    variance += x_local[cnt] * x_local[cnt];
+    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt];
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
+    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
+  }
+}
+
+void rms_layernorm(
+  torch::Tensor& out,     // [..., hidden_size]
+  torch::Tensor& input,   // [..., hidden_size]
+  torch::Tensor& weight,  // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    input.scalar_type(),
+    "rms_layernorm_kernel",
+    rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      out.data_ptr<scalar_t>(),
+      input.data_ptr<scalar_t>(),
+      weight.data_ptr<scalar_t>(),
+      epsilon,
+      num_tokens,
+      hidden_size);)
+}
+
+void fused_add_rms_layernorm(
+  torch::Tensor& input,     // [..., hidden_size]
+  torch::Tensor& residual,  // [..., hidden_size]
+  torch::Tensor& weight,    // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    input.scalar_type(),
+    "fused_add_rms_layernorm_kernel",
+    fused_add_rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      input.data_ptr<scalar_t>(),
+      residual.data_ptr<scalar_t>(),
+      weight.data_ptr<scalar_t>(),
+      epsilon,
+      num_tokens,
+      hidden_size);)
+}
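
Both kernels launch one thread block per token (`grid(num_tokens)`, up to 1024 threads) and cache at most 8 elements per thread in `x_local`, so they assume `hidden_size <= 8 * 1024 = 8192`, as the comment above spells out. The commit message also mentions a unit test; its diff is not part of this excerpt, but a minimal correctness check against an fp32 reference (a hypothetical sketch, not necessarily the test that was added) could look like:

import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()


@pytest.mark.parametrize("num_tokens", [1, 16])
@pytest.mark.parametrize("hidden_size", [1024, 4096])
def test_rms_layernorm(num_tokens, hidden_size):
    torch.manual_seed(0)
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
    weight = torch.randn(hidden_size, dtype=torch.float16, device="cuda")
    eps = 1e-5

    out = torch.empty_like(x)
    inference_ops.rms_layernorm(out, x, weight, eps)

    # fp32 reference, matching what the kernel computes per token:
    # normalize in fp32, cast back, then scale by the weight.
    x32 = x.float()
    ref = (x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)).half() * weight

    assert torch.allclose(out, ref, atol=1e-2, rtol=1e-3)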

extensions/inference/inference_ops_cuda.py

Lines changed: 2 additions & 1 deletion
@@ -12,12 +12,13 @@ def sources_files(self):
             for fname in [
                 "cuda/colossal_inference_C_frontend.cpp",
                 "cuda/decode_kv_cache_memcpy_kernel.cu",
+                "cuda/rms_layernorm_kernel.cu",
             ]
         ]
         return ret

     def include_dirs(self):
-        ret = [self.get_cuda_home_include()]
+        ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
         return ret

     def cxx_flags(self):
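
`include_dirs` now also exposes `cuda/include` under `extensions/csrc`, which is presumably where the `block_reduce.h` and `type_shim.h` headers used by the new kernel live; without it the added `.cu` source would fail to compile. The `InferenceOpsLoader` drives the real build, but an equivalent hand-rolled JIT build with PyTorch's generic extension loader (a hypothetical sketch, not ColossalAI's loader) would pass the same sources and include path:

from torch.utils.cpp_extension import load

# Hypothetical standalone build of the same sources; paths are relative to extensions/csrc.
inference_ops = load(
    name="colossal_inference_ops",
    sources=[
        "cuda/colossal_inference_C_frontend.cpp",
        "cuda/decode_kv_cache_memcpy_kernel.cu",
        "cuda/rms_layernorm_kernel.cu",
    ],
    extra_include_paths=["cuda/include"],  # provides block_reduce.h and type_shim.h
    verbose=True,
)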
