Skip to content

Commit bc5caa4

Browse files
authored
Support various block sizes & Change default block size to 16 (vllm-project#38)
1 parent 52d027d commit bc5caa4

File tree

7 files changed: +602 additions, −619 deletions

benchmark/benchmark_text_completion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def get_sampling_dir_name(
268268
f'{model_name}-tp{args.tensor_parallel_size}',
269269
sample_dir,
270270
'cacheflow',
271+
f'block{args.block_size}',
271272
f'req-rate-{args.request_rate}',
272273
f'seed{args.seed}',
273274
f'duration-{args.duration}',

cacheflow/master/block_manager.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@ def __init__(
1515
block_size: int,
1616
num_blocks: int,
1717
) -> None:
18-
if block_size not in [8, 16, 32]:
19-
raise ValueError(f'Unsupported block size: {block_size}'
20-
'The block size must be one of {8, 16, 32}.')
2118
self.device = device
2219
self.block_size = block_size
2320
self.num_blocks = num_blocks

cacheflow/master/scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def _schedule(
125125

126126
# Swap in the sequence groups in the SWAPPED state if possible.
127127
self.swapped = self.policy.sort_by_priority(now, self.swapped)
128-
while self.swapped:
128+
# FCFS
129+
while self.swapped and not blocks_to_swap_out:
129130
seq_group = self.swapped[0]
130131
# If the sequence group has been preempted in this step, stop.
131132
if seq_group in preempted:

cacheflow/master/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ def add_server_arguments(parser: argparse.ArgumentParser):
180180
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
181181
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
182182
# KV cache arguments
183-
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
183+
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
184184
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
185-
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
185+
parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
186186
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
187187
parser.add_argument('--seed', type=int, default=0, help='random seed')
188188
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')

csrc/attention.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,9 @@ void single_query_cached_kv_attention(
1111
int block_size,
1212
int max_context_len);
1313

14-
void multi_query_cached_kv_attention(
15-
torch::Tensor& cu_query_lens,
16-
torch::Tensor& out,
17-
torch::Tensor& query,
18-
torch::Tensor& key_cache,
19-
torch::Tensor& value_cache,
20-
float scale,
21-
torch::Tensor& block_tables,
22-
torch::Tensor& context_lens,
23-
int block_size,
24-
int max_context_len);
25-
2614
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
2715
m.def(
2816
"single_query_cached_kv_attention",
2917
&single_query_cached_kv_attention,
3018
"Compute the attention between an input query and the cached key/value tensors");
31-
m.def(
32-
"multi_query_cached_kv_attention",
33-
&multi_query_cached_kv_attention,
34-
"Compute the attention between multiple input queries and the cached key/value tensors");
3519
}

Comments (0)