22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44import itertools
5- from typing import Optional , Union
65
76import torch
87
@@ -23,8 +22,16 @@ def norm(x, eps: float):
2322 return x / torch .sqrt (x .pow (2 ).mean (- 1 , keepdim = True ) + eps )
2423
2524 x = x .float ()
26- return (weight [0 ] * norm (x ** 3 , eps ) + weight [1 ] * norm (x ** 2 , eps ) +
27- weight [2 ] * norm (x , eps ) + bias ).to (weight .dtype ).view (orig_shape )
25+ return (
26+ (
27+ weight [0 ] * norm (x ** 3 , eps )
28+ + weight [1 ] * norm (x ** 2 , eps )
29+ + weight [2 ] * norm (x , eps )
30+ + bias
31+ )
32+ .to (weight .dtype )
33+ .view (orig_shape )
34+ )
2835
2936
3037def polynorm_vllm (
@@ -44,18 +51,14 @@ def polynorm_vllm(
4451 return output
4552
4653
47- def calculate_diff (batch_size , seq_len , hidden_size ):
54+ def calculate_diff (batch_size , seq_len , hidden_dim ):
4855 dtype = torch .bfloat16
49- x = torch .randn (batch_size ,
50- seq_len ,
51- hidden_size ,
52- dtype = dtype ,
53- device = "cuda" )
56+ x = torch .randn (batch_size , seq_len , hidden_dim , dtype = dtype , device = "cuda" )
5457 weight = torch .ones (3 , dtype = dtype , device = "cuda" )
5558 bais = torch .ones (1 , dtype = dtype , device = "cuda" )
5659
57- output_naive = polynorm_naive (x . clone () , weight , bais )
58- output_vllm = polynorm_vllm (x . clone () , weight , bais )
60+ output_naive = polynorm_naive (x , weight , bais )
61+ output_vllm = polynorm_vllm (x , weight , bais )
5962
6063 if torch .allclose (output_naive , output_vllm , atol = 1e-2 , rtol = 1e-2 ):
6164 print ("✅ All implementations match" )
@@ -65,47 +68,42 @@ def calculate_diff(batch_size, seq_len, hidden_size):
6568
6669batch_size_range = [2 ** i for i in range (0 , 7 , 2 )]
6770seq_length_range = [2 ** i for i in range (6 , 11 , 1 )]
68- head_num_range = [32 , 48 ]
69- configs = list (
70- itertools .product (head_num_range , batch_size_range , seq_length_range ))
71+ dim_range = [2048 , 4096 ]
72+ configs = list (itertools .product (dim_range , batch_size_range , seq_length_range ))
7173
7274
7375def get_benchmark ():
74-
7576 @triton .testing .perf_report (
7677 triton .testing .Benchmark (
77- x_names = ["head_num " , "batch_size" , "seq_len" ],
78+ x_names = ["dim " , "batch_size" , "seq_len" ],
7879 x_vals = [list (_ ) for _ in configs ],
7980 line_arg = "provider" ,
8081 line_vals = ["naive" , "vllm" ],
8182 line_names = ["Naive" , "vLLM" ],
8283 styles = [("blue" , "-" ), ("red" , "-" )],
8384 ylabel = "us" ,
84- plot_name = f "polynorm-perf" ,
85+ plot_name = "polynorm-perf" ,
8586 args = {},
86- ))
87- def benchmark (head_num , batch_size , seq_len , provider ):
87+ )
88+ )
89+ def benchmark (dim , batch_size , seq_len , provider ):
8890 dtype = torch .bfloat16
89- hidden_size = head_num * 128 # assuming head_dim = 128
91+ hidden_dim = dim * 4
9092
91- x = torch .randn (batch_size ,
92- seq_len ,
93- hidden_size ,
94- dtype = dtype ,
95- device = "cuda" )
93+ x = torch .randn (batch_size , seq_len , hidden_dim , dtype = dtype , device = "cuda" )
9694 weight = torch .ones (3 , dtype = dtype , device = "cuda" )
9795 bias = torch .ones (1 , dtype = dtype , device = "cuda" )
9896
9997 quantiles = [0.5 , 0.2 , 0.8 ]
10098
10199 if provider == "naive" :
102100 ms , min_ms , max_ms = triton .testing .do_bench (
103- lambda : polynorm_naive (x . clone () , weight , bias ),
101+ lambda : polynorm_naive (x , weight , bias ),
104102 quantiles = quantiles ,
105103 )
106104 else :
107105 ms , min_ms , max_ms = triton .testing .do_bench (
108- lambda : polynorm_vllm (x . clone () , weight , bias ),
106+ lambda : polynorm_vllm (x , weight , bias ),
109107 quantiles = quantiles ,
110108 )
111109
@@ -131,10 +129,10 @@ def benchmark(head_num, batch_size, seq_len, provider):
131129 help = "Sequence length" ,
132130 )
133131 parser .add_argument (
134- "--hidden-size " ,
132+ "--hidden-dim " ,
135133 type = int ,
136- default = 4096 ,
137- help = "Hidden size (2nd dimension) of the sequence " ,
134+ default = 8192 ,
135+ help = "Intermediate size of MLP " ,
138136 )
139137 parser .add_argument (
140138 "--save-path" ,
@@ -149,7 +147,7 @@ def benchmark(head_num, batch_size, seq_len, provider):
149147 calculate_diff (
150148 batch_size = args .batch_size ,
151149 seq_len = args .seq_len ,
152- hidden_size = args .hidden_size ,
150+ hidden_dim = args .hidden_dim ,
153151 )
154152
155153 benchmark = get_benchmark ()
0 commit comments