Skip to content

Support various block sizes #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmark/benchmark_text_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def get_sampling_dir_name(
f'{model_name}-tp{args.tensor_parallel_size}',
sample_dir,
'cacheflow',
f'block{args.block_size}',
f'req-rate-{args.request_rate}',
f'seed{args.seed}',
f'duration-{args.duration}',
Expand Down
3 changes: 0 additions & 3 deletions cacheflow/master/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ def __init__(
block_size: int,
num_blocks: int,
) -> None:
if block_size not in [8, 16, 32]:
raise ValueError(f'Unsupported block size: {block_size}'
'The block size must be one of {8, 16, 32}.')
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
Expand Down
3 changes: 2 additions & 1 deletion cacheflow/master/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def _schedule(

# Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped)
while self.swapped:
# FCFS
while self.swapped and not blocks_to_swap_out:
seq_group = self.swapped[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
Expand Down
4 changes: 2 additions & 2 deletions cacheflow/master/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
Expand Down
16 changes: 0 additions & 16 deletions csrc/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,9 @@ void single_query_cached_kv_attention(
int block_size,
int max_context_len);

void multi_query_cached_kv_attention(
torch::Tensor& cu_query_lens,
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"single_query_cached_kv_attention",
&single_query_cached_kv_attention,
"Compute the attention between an input query and the cached key/value tensors");
m.def(
"multi_query_cached_kv_attention",
&multi_query_cached_kv_attention,
"Compute the attention between multiple input queries and the cached key/value tensors");
}
Loading