@@ -1,9 +1,11 @@
+from typing import Optional
import argparse
import random
import time

import torch

+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops

NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
+    kv_cache_dtype: Optional[str] = None,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
-    key_cache.uniform_(-scale, scale)
-    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
-    value_cache = torch.empty(size=value_cache_shape,
-                              dtype=dtype,
-                              device="cuda")
-    value_cache.uniform_(-scale, scale)
+    key_caches, value_caches = create_kv_caches_with_random(
+        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
+        dtype)
+    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
@@ -106,6 +104,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                    block_size,
                    max_context_len,
                    alibi_slopes,
+                    kv_cache_dtype,
                )
            elif version == "v2":
                ops.paged_attention_v2(
@@ -123,6 +122,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                    block_size,
                    max_context_len,
                    alibi_slopes,
+                    kv_cache_dtype,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,18 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
-    dtype_to_torch_dtype = {
-        "half": torch.half,
-        "bfloat16": torch.bfloat16,
-        "float": torch.float,
-    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
@@ -187,7 +189,8 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
-        dtype=dtype_to_torch_dtype[args.dtype],
+        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
+        kv_cache_dtype=args.kv_cache_dtype,
    )
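
For a quick sanity check of the code path this diff switches to, the sketch below mirrors the calls shown above. It is a minimal, illustrative snippet rather than part of the patch: the sizes are arbitrary, and reading the literal 1 as a layer count in create_kv_caches_with_random is an assumption inferred from the call site, not from its definition.

# Minimal sketch (not part of the patch): exercise the helpers the benchmark
# now relies on. Requires a CUDA-capable vLLM installation.
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

model_dtype = STR_DTYPE_TO_TORCH_DTYPE["half"]  # same lookup main() now receives

# Illustrative sizes: 1024 blocks, block size 16, one layer (assumed meaning of
# the literal 1, matching the call in main()), 8 KV heads, head size 128.
key_caches, value_caches = create_kv_caches_with_random(
    1024, 16, 1, 8, 128, "fp8_e5m2", model_dtype)
key_cache, value_cache = key_caches[0], value_caches[0]

# These are the tensors the benchmark hands to ops.paged_attention_v1/v2,
# together with kv_cache_dtype, which the kernels now take as an extra argument.
print(key_cache.shape, key_cache.dtype)
print(value_cache.shape, value_cache.dtype)

Running the benchmark itself with the new cache format only requires passing --kv-cache-dtype fp8_e5m2 alongside the existing options; the default "auto" keeps the model data type, as the help text above states.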