Skip to content

Commit 02222a0

Browse files
ywang96BBuf
andauthored
[Misc] Kernel Benchmark for RMSNorm (vllm-project#11241)
Signed-off-by: Roger Wang <ywang@roblox.com> Co-authored-by: Xiaoyu Zhang <BBuf@users.noreply.github.com>
1 parent 2bfdbf2 commit 02222a0

File tree

1 file changed

+262
-0
lines changed

1 file changed

+262
-0
lines changed
+262
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
import itertools
2+
from typing import Optional, Tuple, Union
3+
4+
import torch
5+
import triton
6+
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
7+
from torch import nn
8+
9+
from vllm import _custom_ops as vllm_ops
10+
11+
12+
class HuggingFaceRMSNorm(nn.Module):
13+
14+
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
15+
super().__init__()
16+
self.weight = nn.Parameter(torch.ones(hidden_size))
17+
self.variance_epsilon = eps
18+
19+
def forward(
20+
self,
21+
x: torch.Tensor,
22+
residual: Optional[torch.Tensor] = None,
23+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
24+
orig_dtype = x.dtype
25+
x = x.to(torch.float32)
26+
if residual is not None:
27+
x = x + residual.to(torch.float32)
28+
residual = x.to(orig_dtype)
29+
30+
variance = x.pow(2).mean(dim=-1, keepdim=True)
31+
x = x * torch.rsqrt(variance + self.variance_epsilon)
32+
x = x.to(orig_dtype) * self.weight
33+
if residual is None:
34+
return x
35+
else:
36+
return x, residual
37+
38+
39+
def rmsnorm_naive(
40+
x: torch.Tensor,
41+
weight: torch.Tensor,
42+
residual: Optional[torch.Tensor] = None,
43+
eps: float = 1e-6,
44+
):
45+
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
46+
naive_norm.weight = nn.Parameter(weight)
47+
naive_norm = naive_norm.to(x.device)
48+
49+
orig_shape = x.shape
50+
x = x.view(-1, x.shape[-1])
51+
if residual is not None:
52+
residual = residual.view(-1, residual.shape[-1])
53+
54+
output = naive_norm(x, residual)
55+
56+
if isinstance(output, tuple):
57+
output = (output[0].view(orig_shape), output[1].view(orig_shape))
58+
else:
59+
output = output.view(orig_shape)
60+
return output
61+
62+
63+
def rmsnorm_flashinfer(
64+
x: torch.Tensor,
65+
weight: torch.Tensor,
66+
residual: Optional[torch.Tensor] = None,
67+
eps: float = 1e-6,
68+
):
69+
orig_shape = x.shape
70+
x = x.view(-1, x.shape[-1])
71+
if residual is not None:
72+
residual = residual.view(-1, residual.shape[-1])
73+
74+
if residual is not None:
75+
fused_add_rmsnorm(x, residual, weight, eps)
76+
output = (x, residual)
77+
else:
78+
output = rmsnorm(x, weight, eps)
79+
80+
if isinstance(output, tuple):
81+
output = (output[0].view(orig_shape), output[1].view(orig_shape))
82+
else:
83+
output = output.view(orig_shape)
84+
return output
85+
86+
87+
def rmsnorm_vllm(
88+
x: torch.Tensor,
89+
weight: torch.Tensor,
90+
residual: Optional[torch.Tensor] = None,
91+
eps: float = 1e-6,
92+
):
93+
orig_shape = x.shape
94+
x = x.view(-1, x.shape[-1])
95+
if residual is not None:
96+
residual = residual.view(-1, residual.shape[-1])
97+
98+
if residual is not None:
99+
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
100+
output = (x, residual)
101+
else:
102+
out = torch.empty_like(x)
103+
vllm_ops.rms_norm(out, x, weight, eps)
104+
output = out
105+
106+
if isinstance(output, tuple):
107+
output = (output[0].view(orig_shape), output[1].view(orig_shape))
108+
else:
109+
output = output.view(orig_shape)
110+
return output
111+
112+
113+
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
114+
dtype = torch.bfloat16
115+
x = torch.randn(batch_size,
116+
seq_len,
117+
hidden_size,
118+
dtype=dtype,
119+
device="cuda")
120+
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
121+
residual = torch.randn_like(x) if use_residual else None
122+
123+
output_naive = rmsnorm_naive(
124+
x.clone(), weight,
125+
residual.clone() if residual is not None else None)
126+
output_flashinfer = rmsnorm_flashinfer(
127+
x.clone(), weight,
128+
residual.clone() if residual is not None else None)
129+
output_vllm = rmsnorm_vllm(
130+
x.clone(), weight,
131+
residual.clone() if residual is not None else None)
132+
133+
if use_residual:
134+
output_naive = output_naive[0]
135+
output_flashinfer = output_flashinfer[0]
136+
output_vllm = output_vllm[0]
137+
138+
print(f"Naive output={output_naive}")
139+
print(f"FlashInfer output={output_flashinfer}")
140+
print(f"VLLM output={output_vllm}")
141+
142+
if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
143+
rtol=1e-2) and torch.allclose(
144+
output_naive, output_vllm, atol=1e-2, rtol=1e-2):
145+
print("✅ All implementations match")
146+
else:
147+
print("❌ Implementations differ")
148+
149+
150+
batch_size_range = [2**i for i in range(0, 7, 2)]
151+
seq_length_range = [2**i for i in range(6, 11, 1)]
152+
head_num_range = [32, 48]
153+
configs = list(
154+
itertools.product(head_num_range, batch_size_range, seq_length_range))
155+
156+
157+
def get_benchmark(use_residual):
158+
159+
@triton.testing.perf_report(
160+
triton.testing.Benchmark(
161+
x_names=["head_num", "batch_size", "seq_len"],
162+
x_vals=[list(_) for _ in configs],
163+
line_arg="provider",
164+
line_vals=["huggingface", "flashinfer", "vllm"],
165+
line_names=["HuggingFace", "FlashInfer", "vLLM"],
166+
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
167+
ylabel="us",
168+
plot_name=
169+
f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
170+
args={},
171+
))
172+
def benchmark(head_num, batch_size, seq_len, provider):
173+
dtype = torch.bfloat16
174+
hidden_size = head_num * 128 # assuming head_dim = 128
175+
176+
x = torch.randn(batch_size,
177+
seq_len,
178+
hidden_size,
179+
dtype=dtype,
180+
device="cuda")
181+
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
182+
residual = torch.randn_like(x) if use_residual else None
183+
184+
quantiles = [0.5, 0.2, 0.8]
185+
186+
if provider == "huggingface":
187+
ms, min_ms, max_ms = triton.testing.do_bench(
188+
lambda: rmsnorm_naive(
189+
x.clone(),
190+
weight,
191+
residual.clone() if residual is not None else None,
192+
),
193+
quantiles=quantiles,
194+
)
195+
elif provider == "flashinfer":
196+
ms, min_ms, max_ms = triton.testing.do_bench(
197+
lambda: rmsnorm_flashinfer(
198+
x.clone(),
199+
weight,
200+
residual.clone() if residual is not None else None,
201+
),
202+
quantiles=quantiles,
203+
)
204+
else:
205+
ms, min_ms, max_ms = triton.testing.do_bench(
206+
lambda: rmsnorm_vllm(
207+
x.clone(),
208+
weight,
209+
residual.clone() if residual is not None else None,
210+
),
211+
quantiles=quantiles,
212+
)
213+
214+
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
215+
216+
return benchmark
217+
218+
219+
if __name__ == "__main__":
220+
import argparse
221+
222+
parser = argparse.ArgumentParser()
223+
parser.add_argument(
224+
"--batch-size",
225+
type=int,
226+
default=4,
227+
help="Batch size",
228+
)
229+
parser.add_argument(
230+
"--seq-len",
231+
type=int,
232+
default=128,
233+
help="Sequence length",
234+
)
235+
parser.add_argument(
236+
"--hidden-size",
237+
type=int,
238+
default=4096,
239+
help="Hidden size (2nd dimension) of the sequence",
240+
)
241+
parser.add_argument("--use-residual",
242+
action="store_true",
243+
help="Whether to use residual connection")
244+
parser.add_argument(
245+
"--save-path",
246+
type=str,
247+
default="./configs/rmsnorm/",
248+
help="Path to save rmsnorm benchmark results",
249+
)
250+
251+
args = parser.parse_args()
252+
253+
# Run correctness test
254+
calculate_diff(batch_size=args.batch_size,
255+
seq_len=args.seq_len,
256+
hidden_size=args.hidden_size,
257+
use_residual=args.use_residual)
258+
259+
# Get the benchmark function with proper use_residual setting
260+
benchmark = get_benchmark(args.use_residual)
261+
# Run performance benchmark
262+
benchmark.run(print_data=True, save_path=args.save_path)

0 commit comments

Comments
 (0)