[Performance] Enable chunked prefill and prefix caching together #7753

Merged (8 commits) on Aug 28, 2024
66 changes: 66 additions & 0 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -6,6 +6,7 @@

Run `pytest tests/basic_correctness/test_chunked_prefill.py`.
"""
from contextlib import nullcontext

import pytest

@@ -150,3 +151,68 @@ def test_models_with_fp8_kv_cache(
name_0="no_chunked_prefill",
name_1="chunked_prefill",
)


@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset the distributed env properly. Use a value > 1 only when testing locally.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_with_prefix_caching(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
use_v2_block_manager: bool,
tensor_parallel_size: int,
) -> None:
"""
Checks exact match decode with and without prefix caching
with chunked prefill enabled.
"""
model = "meta-llama/Llama-2-7b-chat-hf"
# The common prompt has 142 tokens with Llama-2 tokenizer.
common_prompt = "You are a helpful AI assistant " * 20
unique_prompts = [
"Question", # Warmup
"Question", # Fully cached
"Another question", # Partial cached
]
full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]

max_num_batched_tokens = max_num_seqs = chunk_size
outputs = {} # type: ignore
check_result = True
for enable in (True, False):
with vllm_runner(
model,
dtype="half",
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True,
enable_prefix_caching=enable,
tensor_parallel_size=tensor_parallel_size,
use_v2_block_manager=use_v2_block_manager,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
) as vllm_model:
# It should fail when prefix caching is enabled and the chunk
# size is not a multiple of the block size (16).
should_fail = chunk_size % 16 != 0 and enable
check_result &= not should_fail
outputs[enable] = []
# Send the request one-by-one to ensure the cache is populated.
with pytest.raises(ValueError) if should_fail else nullcontext():
for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens)

# Check results only if we did not expect a failure.
if check_result:
check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)
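
For reference, here is a minimal end-to-end sketch of turning both features on together outside the test harness. It assumes the vllm.LLM entrypoint forwards the usual engine arguments (enable_chunked_prefill, enable_prefix_caching, max_num_batched_tokens) and keeps the chunk size a multiple of the 16-token block size so the new scheduler check does not raise:

from vllm import LLM, SamplingParams

# Usage sketch: chunk size (max_num_batched_tokens) is a multiple of the
# block size (16), which this PR requires when both features are enabled.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
    max_num_batched_tokens=32,
)
params = SamplingParams(temperature=0.0, max_tokens=16)
common = "You are a helpful AI assistant " * 20
# The second prompt should reuse the prefix blocks cached by the first.
outputs = llm.generate([common + "\nQuestion", common + "\nAnother question"], params)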
40 changes: 40 additions & 0 deletions tests/core/test_block_manager.py
@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():

# assert all blocks are free now
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks


def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
[Review comment] Collaborator: Do we have a corresponding test in v2?
[Author reply] We don't need to test v2 because v2 automatically marks touched blocks as computed.

"""When prefix cache and chunked prefill are enabled, the block manager
should only mark a chunk of blocks as computed instead of all blocks.
"""

block_size = 4
num_cpu_blocks = 0
num_gpu_blocks = 16
block_manager = BlockSpaceManagerV1(block_size,
num_gpu_blocks,
num_cpu_blocks,
watermark=0,
enable_caching=True)

# Set prompt size to have num_gpu_blocks - 1 full blocks.
prompt_length = block_size * num_gpu_blocks - 1

# Allocate (reserve) all blocks.
_, seq_group = create_dummy_prompt("0",
prompt_length,
block_size=block_size)
block_manager.allocate(seq_group)
assert seq_group.seqs[0].n_blocks == num_gpu_blocks

# 1st chunk: Compute 2 and a half blocks. Should mark 2 blocks as computed.
token_chunk_size = int(block_size * 2.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 2

# Actual computed tokens.
seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)

# 2nd chunk: Complete the 3rd block and fill 4 additional blocks.
token_chunk_size = int(block_size * 4.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 7
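
The expected block counts in this test are just integer division by the block size. A tiny standalone sketch of the same arithmetic (local names, not the block manager API):

block_size = 4

def full_blocks(num_computed_tokens: int, token_chunk_size: int) -> int:
    # Only fully filled blocks are marked as computed.
    return (num_computed_tokens + token_chunk_size) // block_size

# 1st chunk: 0 computed tokens + 10 new tokens -> 10 // 4 = 2 full blocks.
assert full_blocks(0, int(block_size * 2.5)) == 2
# 2nd chunk: 10 computed tokens + 18 new tokens -> 28 // 4 = 7 full blocks.
assert full_blocks(int(block_size * 2.5), int(block_size * 4.5)) == 7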
39 changes: 39 additions & 0 deletions tests/core/test_chunked_prefill_scheduler.py
@@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
assert len(get_sequence_groups(out)) == max_seqs
assert not running[0].is_prefill()
assert not running[1].is_prefill()


def test_prefix_caching():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size = 4
max_seqs = 10
max_model_len = 80
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True)
cache_config = CacheConfig(block_size,
1.0,
1,
"auto",
enable_prefix_caching=True)
cache_config.num_cpu_blocks = 0
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
block_size=block_size,
prompt_length=50)
scheduler.add_seq_group(seq_group)
running.append(seq_group)

seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 50
# Verify it is chunked. Note that although the remaining budget is 64 - 50 = 14,
# we only allocate full blocks when prefix caching is enabled, so only
# 4 * (14 // 4) = 12 tokens are allocated.
assert seq_group_meta[1].token_chunk_size == 12
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 62
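
The 62 batched tokens follow from the block-aligned chunking rule this PR adds to the scheduler (see the vllm/core/scheduler.py diff below). A standalone sketch of the same arithmetic:

block_size = 4
token_budget = 64

first_chunk = 50                                        # first sequence fits entirely
remaining = token_budget - first_chunk                  # 14 tokens of budget left
second_chunk = (remaining // block_size) * block_size   # round down to full blocks
assert second_chunk == 12
assert first_chunk + second_chunk == 62                 # matches out.num_batched_tokens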
19 changes: 13 additions & 6 deletions vllm/core/block_manager_v1.py
@@ -680,14 +680,20 @@ def access_all_blocks_in_seq(
for block in block_table:
block.last_accessed = access_time

def compute_full_blocks_in_seq(self, seq: Sequence):
def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
if seq.seq_id not in self.block_tables:
return
max_full_block = seq.get_len() // self.block_size - 1

# When chunked prefill is enabled, the computed full blocks
# should be calculated based on the number of computed tokens.
max_computed_tokens = (seq.data.get_num_computed_tokens() +
token_chunk_size)
computed_full_blocks = max_computed_tokens // self.block_size

block_table = self.block_tables[seq.seq_id]
if max_full_block == -1:
if computed_full_blocks == 0:
return
for i in reversed(range(max_full_block)):
for i in reversed(range(computed_full_blocks)):
if block_table[i].computed:
break
block_table[i].computed = True
@@ -717,10 +723,11 @@ def get_common_computed_block_ids(
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
return commonprefix([ids for ids in ids_list if ids != []])

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)
self.compute_full_blocks_in_seq(seq, token_chunk_size)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
if device == Device.GPU:
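
To make the marking logic above concrete, here is an illustrative sketch with toy objects (not the vLLM Block class): blocks are marked from the last newly computed full block backwards, stopping at the first block that an earlier chunk already marked.

class ToyBlock:
    def __init__(self) -> None:
        self.computed = False

def mark(block_table: list, computed_full_blocks: int) -> None:
    # Walk backwards and stop early once a previously marked block is hit.
    for i in reversed(range(computed_full_blocks)):
        if block_table[i].computed:
            break
        block_table[i].computed = True

block_table = [ToyBlock() for _ in range(8)]
mark(block_table, 2)   # 1st chunk marks blocks 0-1
mark(block_table, 7)   # 2nd chunk marks blocks 2-6, then stops at block 1
assert sum(b.computed for b in block_table) == 7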
3 changes: 2 additions & 1 deletion vllm/core/block_manager_v2.py
@@ -286,7 +286,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now)

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
# If prefix caching is enabled, mark immutable blocks as computed
# right after they have been scheduled (for prefill). This assumes
# the scheduler is synchronous so blocks are actually computed when
3 changes: 2 additions & 1 deletion vllm/core/embedding_model_block_manager.py
@@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
return None # type: ignore

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass

def get_prefix_cache_hit_rate(self, device: Device) -> float:
Expand Down
3 changes: 2 additions & 1 deletion vllm/core/interfaces.py
@@ -115,7 +115,8 @@ def get_common_computed_block_ids(
pass

@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass

@abstractmethod
30 changes: 24 additions & 6 deletions vllm/core/scheduler.py
@@ -1145,7 +1145,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# will crash the vLLM instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
scheduled_seq_group.seq_group,
scheduled_seq_group.token_chunk_size)

scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
@@ -1347,10 +1348,27 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search in a
# decode phase. Do not chunk in that case.
# Chunk if a running request cannot fit in the given budget.
# If the number of seqs > 1, the group is doing beam search
# in the decode phase. Do not chunk in that case.
if enable_chunking and len(seqs) == 1:
num_new_tokens = min(num_new_tokens,
budget.remaining_token_budget())
remaining_token_budget = budget.remaining_token_budget()
if self.cache_config.enable_prefix_caching:
# When prefix caching is enabled, we always allocate
# a number of new tokens that is divisible by the block size
# to avoid partial block matching.
block_size = self.cache_config.block_size
remainder = budget.token_budget % block_size
if remainder != 0:
[Review comment] Collaborator: Btw, should we raise this exception at engine start time instead and just add an assert here?
[Author reply] I feel we could just raise here for now, because this constraint should be removable once we refactor the scheduler to consider prefix caching.

raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {reminder}")
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens
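
Condensed into a single hypothetical helper, the new chunking rule reads roughly as follows (a sketch of the logic added to _get_num_new_tokens above, not the actual vLLM function):

def chunk_new_tokens(num_new_tokens: int, remaining_budget: int,
                     token_budget: int, block_size: int,
                     enable_prefix_caching: bool) -> int:
    if not enable_prefix_caching:
        return min(num_new_tokens, remaining_budget)
    if token_budget % block_size != 0:
        # Mirrors the ValueError raised above.
        raise ValueError("chunk size must be divisible by block size")
    if remaining_budget < num_new_tokens:
        # Round the chunk down to whole blocks so later requests can match
        # cached blocks exactly (no partial-block matching).
        return (remaining_budget // block_size) * block_size
    return num_new_tokens

# E.g. a 50-token request with 14 tokens of budget left and block size 4
# gets a 12-token chunk, matching the scheduler test above.
assert chunk_new_tokens(50, 14, 64, 4, True) == 12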
49 changes: 37 additions & 12 deletions vllm/worker/model_runner.py
@@ -499,23 +499,48 @@ def _compute_for_prefix_cache_hit(
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")

# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size

if not prefix_cache_hit:
return

assert computed_block_nums is not None
# The number of cache-hit prompt tokens in this sequence. Note that
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
# The number of prompt tokens computed so far in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the number of tokens in
# the already computed chunks plus the current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
uncomputed_start = prefix_cache_len - context_len
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
seq_idx][uncomputed_start:]
context_len = prefix_cache_len

inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1

def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
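
The rewritten prefix-cache handling above distinguishes three cases per chunk. A toy sketch with plain lists (not the real InterDataForSeqGroup) showing the same slicing arithmetic:

def apply_prefix_cache_hit(tokens, positions, context_len, prefix_cache_len):
    seq_len = context_len + len(tokens)
    if prefix_cache_len <= context_len:
        # Case 1: the cache-hit region was already covered by earlier chunks.
        query_len = len(tokens)
    elif context_len < prefix_cache_len < seq_len:
        # Case 2: partial hit -- skip the cached part of this chunk.
        skip = prefix_cache_len - context_len
        tokens, positions = tokens[skip:], positions[skip:]
        context_len = prefix_cache_len
        query_len = len(tokens)
    else:
        # Case 3: full hit -- recompute only the last token.
        tokens, positions = tokens[-1:], positions[-1:]
        context_len, query_len = seq_len - 1, 1
    return tokens, positions, context_len, query_len

# Partial hit: 32 cached tokens, 16 already computed, 30-token chunk.
toks, pos, ctx, qlen = apply_prefix_cache_hit(
    list(range(16, 46)), list(range(16, 46)), context_len=16, prefix_cache_len=32)
assert (ctx, qlen) == (32, 14)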