Fix auto prefix bug #3239

Merged · 4 commits · Mar 8, 2024
Changes from all commits
34 changes: 34 additions & 0 deletions tests/engine/test_computed_prefix_blocks.py
@@ -0,0 +1,34 @@
import pytest

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
def test_computed_prefix_blocks(model: str, block_size: int):
    # This test checks if we are able to run the engine to completion
    # without triggering asserts.
    # We are in a scenario where all blocks from the second request's prompt
    # are full and already computed when the second request arrives.
    prompt = (
        "You are a helpful assistant. How do I build a car from cardboard and "
        "paper clips? Is there an easy to follow video tutorial available "
        "online for free?")
    prompt2 = (
        " Please recommend to me some resources where I can learn not only to "
        "handle technical difficulties of building a car, but also "
        "decoration.")

    engine_args = EngineArgs(model=model,
                             block_size=block_size,
                             enable_prefix_caching=True)

    engine = LLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams()

    engine.add_request("0", prompt + prompt2, sampling_params)
    engine.step()
    engine.add_request("1", prompt, sampling_params)
    engine.step()
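
For reference, a minimal way to run just this regression test (assuming a vLLM source checkout with pytest installed and network access to fetch facebook/opt-125m):

    pytest tests/engine/test_computed_prefix_blocks.py -q
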
28 changes: 16 additions & 12 deletions vllm/core/block_manager.py
@@ -1,6 +1,6 @@
"""A block manager that manages token blocks."""
import enum
from itertools import count
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set, Tuple

@@ -426,23 +426,29 @@ def access_all_blocks_in_seq(
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_last_full_block_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
         max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
-        block_table[max_full_block].computed = True
+        for i in reversed(range(max_full_block)):
+            if block_table[i].computed:
+                break
+            block_table[i].computed = True
 
-    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
         if seq.seq_id not in self.block_tables:
             return []
         block_table = self.block_tables[seq.seq_id]
-        for block_idx in reversed(range(len(block_table))):
-            if block_table[block_idx].computed:
-                return [b.block_number for b in block_table[:block_idx + 1]]
-        return []
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+        return [
+            b.block_number
+            for b in takewhile(lambda b: b.computed, block_table[:-1])
+        ]
 
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
@@ -451,14 +457,12 @@ def get_common_computed_block_ids(self,
             return []
 
         ids_list = [
-            self.get_all_block_ids_till_computed(seq)
+            self.get_all_computed_blocks(seq)
             for seq in iter(seq_group.seqs_dict.values())
         ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
-                self.compute_last_full_block_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq)
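
To make the new bookkeeping easier to follow, here is a self-contained sketch of the computed-prefix logic above. FakeBlock and computed_prefix are hypothetical names used only for illustration (not vLLM's real block class or methods); the sketch shows how takewhile over block_table[:-1] returns just the leading run of computed blocks while always leaving the final block out, and how commonprefix then intersects the per-sequence lists:

from itertools import takewhile
from os.path import commonprefix


class FakeBlock:
    """Hypothetical stand-in for a physical token block (illustration only)."""

    def __init__(self, block_number: int, computed: bool = False):
        self.block_number = block_number
        self.computed = computed


def computed_prefix(block_table):
    # Same shape as get_all_computed_blocks above: take the leading run of
    # computed blocks, but never the last block of the prompt, so the model
    # runner is always left with at least one block to prefill.
    return [
        b.block_number
        for b in takewhile(lambda b: b.computed, block_table[:-1])
    ]


# Three of four prompt blocks are already computed -> their ids are reused.
table = [FakeBlock(7, True), FakeBlock(8, True), FakeBlock(9, True), FakeBlock(3)]
print(computed_prefix(table))  # [7, 8, 9]

# Even if every block were marked computed, the final block is still excluded.
print(computed_prefix([FakeBlock(i, True) for i in range(4)]))  # [0, 1, 2]

# get_common_computed_block_ids intersects per-sequence lists element-wise.
print(commonprefix([[7, 8, 9], [7, 8, 4]]))  # [7, 8]
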
1 change: 1 addition & 0 deletions vllm/worker/model_runner.py
@@ -209,6 +209,7 @@ def _prepare_prompt(
                 slot_mapping[-1].append(slot)
 
         max_prompt_len = max(subquery_lens)
+        assert max_prompt_len > 0
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
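
The new assertion guards the situation the block manager change prevents: if an entire prompt were reported as already computed, the remaining subquery for that sequence would be empty and the prefill batch would be padded to length zero. A simplified, hypothetical illustration of that failure mode (pad_to_max is not vLLM's _make_tensor_with_pad, just a stand-in with the same idea):

import torch


def pad_to_max(seqs, max_len, pad=0):
    # Hypothetical pad-and-stack helper, illustration only.
    return torch.tensor([s + [pad] * (max_len - len(s)) for s in seqs],
                        dtype=torch.long)


# Old behavior: the whole prompt hits the prefix cache, nothing left to prefill.
subquery_lens = [0]
input_tokens = [[]]
max_prompt_len = max(subquery_lens)
print(pad_to_max(input_tokens, max_prompt_len).shape)  # torch.Size([1, 0])

# With the fix, the last prompt block is never treated as cached, so
# max_prompt_len stays >= 1 and the new assert holds.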