3
3
import triton
4
4
from packaging import version
5
5
from transformers .models .llama .modeling_llama import LlamaRMSNorm
6
- from vllm .model_executor .layers .layernorm import RMSNorm
7
6
8
7
from colossalai .kernel .triton import rms_layernorm
9
8
from colossalai .testing .utils import parameterize
@@ -36,7 +35,8 @@ def test_layer_norm(M, N):
36
35
y_triton = rms_layernorm (x , weight , eps = eps )
37
36
y_llama = rms_norm .forward (x ).to (dtype )
38
37
39
- assert torch .allclose (y_triton , y_llama , atol = 1e-5 , rtol = 1e-5 )
38
+ assert y_triton .shape == y_llama .shape
39
+ assert torch .allclose (y_triton , y_llama , atol = 1e-5 , rtol = 1e-3 )
40
40
41
41
42
42
# Triton benchmark plot attributions
@@ -45,8 +45,8 @@ def test_layer_norm(M, N):
45
45
x_names = ["SEQUENCE_TOTAL" ],
46
46
x_vals = [i for i in range (128 , 1025 , 128 )],
47
47
line_arg = "provider" ,
48
- line_vals = ["vllm_rms_layernorm " , "triton_rms_layernorm" ],
49
- line_names = ["vllm_rms_layernorm " , "triton_rms_layernorm" ],
48
+ line_vals = ["torch_rms_layernorm " , "triton_rms_layernorm" ],
49
+ line_names = ["torch_rms_layernorm " , "triton_rms_layernorm" ],
50
50
styles = [("red" , "-" ), ("blue" , "-" )],
51
51
ylabel = "ms" ,
52
52
plot_name = f"RMSNorm benchmarking results" ,
@@ -69,10 +69,10 @@ def benchmark_rms_layernorm(
69
69
x_shape = (SEQUENCE_TOTAL , HIDDEN_SIZE )
70
70
w_shape = (x_shape [- 1 ],)
71
71
weight = torch .ones (w_shape , dtype = dtype , device = "cuda" )
72
- vllm_norm = RMSNorm (hidden_size = HIDDEN_SIZE ).to (dtype = dtype , device = "cuda" )
72
+ torch_norm = LlamaRMSNorm (hidden_size = HIDDEN_SIZE ).to (dtype = dtype , device = "cuda" )
73
73
x = - 2.3 + 0.5 * torch .randn (x_shape , dtype = dtype , device = "cuda" )
74
- if provider == "vllm_rms_layernorm " :
75
- fn = lambda : vllm_norm (x )
74
+ if provider == "torch_rms_layernorm " :
75
+ fn = lambda : torch_norm (x )
76
76
elif provider == "triton_rms_layernorm" :
77
77
fn = lambda : rms_layernorm (x , weight , eps = eps )
78
78
else :
0 commit comments