Skip to content

Commit 1154e07

Browse files
committed
Implementation of FusedPolyNormKernel
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
1 parent ef386df commit 1154e07

File tree

10 files changed

+504
-31
lines changed

10 files changed

+504
-31
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import itertools
5+
from typing import Optional, Union
6+
7+
import torch
8+
9+
from vllm import _custom_ops as vllm_ops
10+
from vllm.triton_utils import triton
11+
12+
13+
def polynorm_naive(
14+
x: torch.Tensor,
15+
weight: torch.Tensor,
16+
bias: torch.Tensor,
17+
eps: float = 1e-6,
18+
):
19+
orig_shape = x.shape
20+
x = x.view(-1, x.shape[-1])
21+
22+
def norm(x, eps: float):
23+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
24+
25+
x = x.float()
26+
return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
27+
weight[2] * norm(x, eps) + bias).to(weight.dtype).view(orig_shape)
28+
29+
30+
def polynorm_vllm(
31+
x: torch.Tensor,
32+
weight: torch.Tensor,
33+
bias: torch.Tensor,
34+
eps: float = 1e-6,
35+
):
36+
orig_shape = x.shape
37+
x = x.view(-1, x.shape[-1])
38+
39+
out = torch.empty_like(x)
40+
vllm_ops.poly_norm(out, x, weight, bias, eps)
41+
output = out
42+
43+
output = output.view(orig_shape)
44+
return output
45+
46+
47+
def calculate_diff(batch_size, seq_len, hidden_size):
48+
dtype = torch.bfloat16
49+
x = torch.randn(batch_size,
50+
seq_len,
51+
hidden_size,
52+
dtype=dtype,
53+
device="cuda")
54+
weight = torch.ones(3, dtype=dtype, device="cuda")
55+
bais = torch.ones(1, dtype=dtype, device="cuda")
56+
57+
output_naive = polynorm_naive(x.clone(), weight, bais)
58+
output_vllm = polynorm_vllm(x.clone(), weight, bais)
59+
60+
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
61+
print("✅ All implementations match")
62+
else:
63+
print("❌ Implementations differ")
64+
65+
66+
batch_size_range = [2**i for i in range(0, 7, 2)]
67+
seq_length_range = [2**i for i in range(6, 11, 1)]
68+
head_num_range = [32, 48]
69+
configs = list(
70+
itertools.product(head_num_range, batch_size_range, seq_length_range))
71+
72+
73+
def get_benchmark():
74+
75+
@triton.testing.perf_report(
76+
triton.testing.Benchmark(
77+
x_names=["head_num", "batch_size", "seq_len"],
78+
x_vals=[list(_) for _ in configs],
79+
line_arg="provider",
80+
line_vals=["naive", "vllm"],
81+
line_names=["Naive", "vLLM"],
82+
styles=[("blue", "-"), ("red", "-")],
83+
ylabel="us",
84+
plot_name=f"polynorm-perf",
85+
args={},
86+
))
87+
def benchmark(head_num, batch_size, seq_len, provider):
88+
dtype = torch.bfloat16
89+
hidden_size = head_num * 128 # assuming head_dim = 128
90+
91+
x = torch.randn(batch_size,
92+
seq_len,
93+
hidden_size,
94+
dtype=dtype,
95+
device="cuda")
96+
weight = torch.ones(3, dtype=dtype, device="cuda")
97+
bias = torch.ones(1, dtype=dtype, device="cuda")
98+
99+
quantiles = [0.5, 0.2, 0.8]
100+
101+
if provider == "naive":
102+
ms, min_ms, max_ms = triton.testing.do_bench(
103+
lambda: polynorm_naive(x.clone(), weight, bias),
104+
quantiles=quantiles,
105+
)
106+
else:
107+
ms, min_ms, max_ms = triton.testing.do_bench(
108+
lambda: polynorm_vllm(x.clone(), weight, bias),
109+
quantiles=quantiles,
110+
)
111+
112+
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
113+
114+
return benchmark
115+
116+
117+
if __name__ == "__main__":
118+
import argparse
119+
120+
parser = argparse.ArgumentParser()
121+
parser.add_argument(
122+
"--batch-size",
123+
type=int,
124+
default=4,
125+
help="Batch size",
126+
)
127+
parser.add_argument(
128+
"--seq-len",
129+
type=int,
130+
default=128,
131+
help="Sequence length",
132+
)
133+
parser.add_argument(
134+
"--hidden-size",
135+
type=int,
136+
default=4096,
137+
help="Hidden size (2nd dimension) of the sequence",
138+
)
139+
parser.add_argument(
140+
"--save-path",
141+
type=str,
142+
default="./configs/polnorm/",
143+
help="Path to save polnorm benchmark results",
144+
)
145+
146+
args = parser.parse_args()
147+
148+
# Run correctness test
149+
calculate_diff(
150+
batch_size=args.batch_size,
151+
seq_len=args.seq_len,
152+
hidden_size=args.hidden_size,
153+
)
154+
155+
benchmark = get_benchmark()
156+
# Run performance benchmark
157+
benchmark.run(print_data=True, save_path=args.save_path)

0 commit comments

Comments
 (0)