Skip to content

Commit 8aa5fd8

Browse files
committed
use torch for rms_norm benchmarking
1 parent 31b905e commit 8aa5fd8

File tree

1 file changed: +7 −7 lines changed

tests/test_infer_ops/triton/test_rmsnorm_triton.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import triton
44
from packaging import version
55
from transformers.models.llama.modeling_llama import LlamaRMSNorm
6-
from vllm.model_executor.layers.layernorm import RMSNorm
76

87
from colossalai.kernel.triton import rms_layernorm
98
from colossalai.testing.utils import parameterize
@@ -36,7 +35,8 @@ def test_layer_norm(M, N):
3635
y_triton = rms_layernorm(x, weight, eps=eps)
3736
y_llama = rms_norm.forward(x).to(dtype)
3837

39-
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
38+
assert y_triton.shape == y_llama.shape
39+
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
4040

4141

4242
# Triton benchmark plot attributions
@@ -45,8 +45,8 @@ def test_layer_norm(M, N):
4545
x_names=["SEQUENCE_TOTAL"],
4646
x_vals=[i for i in range(128, 1025, 128)],
4747
line_arg="provider",
48-
line_vals=["vllm_rms_layernorm", "triton_rms_layernorm"],
49-
line_names=["vllm_rms_layernorm", "triton_rms_layernorm"],
48+
line_vals=["torch_rms_layernorm", "triton_rms_layernorm"],
49+
line_names=["torch_rms_layernorm", "triton_rms_layernorm"],
5050
styles=[("red", "-"), ("blue", "-")],
5151
ylabel="ms",
5252
plot_name=f"RMSNorm benchmarking results",
@@ -69,10 +69,10 @@ def benchmark_rms_layernorm(
6969
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
7070
w_shape = (x_shape[-1],)
7171
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
72-
vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
72+
torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
7373
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
74-
if provider == "vllm_rms_layernorm":
75-
fn = lambda: vllm_norm(x)
74+
if provider == "torch_rms_layernorm":
75+
fn = lambda: torch_norm(x)
7676
elif provider == "triton_rms_layernorm":
7777
fn = lambda: rms_layernorm(x, weight, eps=eps)
7878
else:

0 commit comments

Comments (0)