[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC #13949
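This PR makes prompt logprobs compatible with automatic prefix caching (APC) in V1: a request that sets prompt_logprobs is no longer rejected when APC is on, but it is treated as a full prefix-cache miss on lookup, so logprobs are computed for every prompt token; the blocks it allocates are still hashed and can be reused by later non-logprobs requests. A minimal sketch of that lookup policy, i.e. the behavior the tests below validate (computed_prefix_for and lookup_cached_blocks are hypothetical names, not the PR's actual KVCacheManager code):

def computed_prefix_for(request, lookup_cached_blocks):
    """Sketch: decide how much of a new request's prompt may be served
    from the prefix cache. Returns (computed_blocks, num_computed_tokens)."""
    if request.sampling_params.prompt_logprobs is not None:
        # Prompt-logprobs request: skip prefix reuse so the model runs over
        # the whole prompt and can emit a logprob for every prompt position.
        return [], 0
    # Ordinary request: reuse whatever full blocks are already cached.
    return lookup_cached_blocks(request)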

Merged: 36 commits, Mar 8, 2025. The diff below shows changes from all commits.

Commits
af8cd05 - implementation first-pass (afeldman-nm, Feb 27, 2025)
1d32c92 - Removed assert against plp+apc (afeldman-nm, Feb 27, 2025)
57aadb6 - Merge branch 'main' into plp_apc (afeldman-nm, Feb 27, 2025)
14bce69 - logprobs enum (afeldman-nm, Feb 27, 2025)
e063bda - refactor (afeldman-nm, Feb 27, 2025)
d107c92 - APC tests (afeldman-nm, Feb 27, 2025)
6bf8810 - Merge branch 'main' into plp_apc (afeldman-nm, Feb 27, 2025)
1d68b46 - refactor (afeldman-nm, Feb 27, 2025)
135b01d - Merge branch 'main' into plp_apc (afeldman-nm, Feb 28, 2025)
80a2057 - fix (afeldman-nm, Feb 28, 2025)
829a26c - merge (afeldman-nm, Mar 2, 2025)
468c104 - merge (afeldman-nm, Mar 3, 2025)
509ac9c - Merge branch 'main' into plp_apc_merge (afeldman-nm, Mar 3, 2025)
5cafa80 - Merge branch 'main' into plp_apc (afeldman-nm, Mar 4, 2025)
51889d1 - revise (afeldman-nm, Mar 4, 2025)
f4b9d6f - revise (afeldman-nm, Mar 4, 2025)
18c3b3a - revise (afeldman-nm, Mar 4, 2025)
718b823 - Merge branch 'main' into plp_apc_merge (afeldman-nm, Mar 4, 2025)
e1c15f5 - Merge branch 'main' into plp_apc (afeldman-nm, Mar 4, 2025)
7b551f5 - cody-recommended approach (afeldman-nm, Mar 4, 2025)
34f425c - removed logic to disable common prefix for prompt logprobs (afeldman-nm, Mar 4, 2025)
2f79888 - Merge branch 'main' into plp_apc_merge (afeldman-nm, Mar 4, 2025)
fe9e655 - merge (afeldman-nm, Mar 6, 2025)
c178d38 - merge (afeldman-nm, Mar 6, 2025)
d810d9b - Merge branch 'main' into plp_apc_merge (afeldman-nm, Mar 7, 2025)
6ec9a6a - remove tests (afeldman-nm, Mar 7, 2025)
f45f0e5 - utils cleanup (afeldman-nm, Mar 7, 2025)
109fdfd - Merge branch 'main' into plp_apc_merge (afeldman-nm, Mar 7, 2025)
23c27ac - Merge branch 'plp_apc_merge' into plp_apc (afeldman-nm, Mar 7, 2025)
843b4d0 - scheduler tests (afeldman-nm, Mar 7, 2025)
1d746ff - computed blocks fix (afeldman-nm, Mar 7, 2025)
077b51f - APC prefill unit test (afeldman-nm, Mar 7, 2025)
4e88c10 - Merge branch 'main' into plp_apc_merge (afeldman-nm, Mar 7, 2025)
d28dc87 - simplify (afeldman-nm, Mar 7, 2025)
4607a6a - comments (afeldman-nm, Mar 7, 2025)
5546bff - removed some tests (afeldman-nm, Mar 7, 2025)
112 changes: 110 additions & 2 deletions tests/v1/core/test_prefix_caching.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the with and without prefix caching."""

from typing import Optional

import pytest

from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -15,7 +17,8 @@
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
mm_hashes=None,
prompt_logprobs: Optional[int] = None):
if mm_positions is None:
multi_modal_inputs = None
else:
@@ -28,7 +31,8 @@ def make_request(request_id,
multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs),
eos_token_id=100,
arrival_time=0,
lora_request=None,
@@ -144,6 +148,110 @@ def test_prefill():
assert manager.block_pool.free_block_queue.free_list_tail is None


def test_prefill_plp():
'''Test prefill with APC and some prompt logprobs (plp) requests.

1. Schedule plp request and validate APC block allocation
2. Schedule non-plp request and validate blocks
3. Schedule plp request; no hit should occur; validate blocks
'''
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)

# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]

# Request #0 is a prompt logprobs request
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks]

# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value

# Check partial/preallocated block metadata
for block_id in (3, 4):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1

# Request #1 is a non-prompt-logprobs request:
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
for block in computed_blocks:
assert block.ref_cnt == 2

# At this point, we should have 3 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 3

manager.free(req0)
manager.free(req1)

# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
unique_token_ids = [3] * 6
req2 = make_request("2",
common_token_ids + unique_token_ids,
prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55, computed_blocks)
block_ids = [b.block_id for b in blocks]
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes
assert block_ids != [0, 1, 2, 3, 4]

# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
for block_id in block_ids:
assert manager.block_pool.blocks[block_id].ref_cnt == 1

manager.free(req2)


def test_decode():
manager = KVCacheManager(
block_size=16,
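
The request #2 assertions above rely on block hashes being chained over (parent hash, block tokens), which is why request #2 reproduces request #0's hashes even though it lands in different physical blocks. A toy illustration of that chaining (chained_block_hashes and the use of Python's built-in hash() are stand-ins for vLLM's hash_block_tokens, purely for illustration):

def chained_block_hashes(token_ids, block_size=16):
    # Only full blocks are hashed; each hash folds in its parent's hash,
    # so identical prefixes yield identical hash chains.
    hashes, parent = [], None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        parent = hash((parent, block))
        hashes.append(parent)
    return hashes

common = [i for i in range(3) for _ in range(16)]   # 3 full blocks of 16
req0_ids = common + [3] * 7                         # request #0 prompt
req2_ids = common + [3] * 6                         # request #2 prompt
assert chained_block_hashes(req0_ids) == chained_block_hashes(req2_ids)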
61 changes: 49 additions & 12 deletions tests/v1/core/test_scheduler.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import pytest

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
@@ -16,7 +18,21 @@ def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
) -> Scheduler:
'''Create scheduler under test.

Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batched_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)

Returns:
:class:`Scheduler` instance
'''
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
@@ -31,11 +47,16 @@ def create_scheduler(
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
@@ -54,16 +75,16 @@
)


def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
):
def create_requests(num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids)
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


def test_schedule():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
requests = create_requests(num_requests=10,
prompt_logprobs=prompt_logprobs)
for request in requests:
scheduler.add_request(request)

@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]


def test_schedule_concurrent_batches():
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
max_num_seqs=2,
enable_prefix_caching=enable_prefix_caching,
)
requests = create_requests(
num_requests=2,
num_tokens=512,
prompt_logprobs=prompt_logprobs,
)

# Schedule the first request.
36 changes: 0 additions & 36 deletions tests/v1/engine/test_async_llm.py
@@ -6,7 +6,6 @@

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import SamplingParams
from vllm.assets.image import ImageAsset
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -72,41 +71,6 @@ async def generate(engine: AsyncLLM,
return count, request_id


@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_async_llm_refuses_prompt_logprobs_with_apc(
monkeypatch, output_kind: RequestOutputKind):
"""Test passes if AsyncLLM raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""
# TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
# better way to test V1 so that in the future when we switch, we don't
# have to change all the tests.
monkeypatch.setenv("VLLM_USE_V1", "1")
# Create AsyncLLM engine with APC
apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m",
enable_prefix_caching=True,
gpu_memory_utilization=0.8,
disable_log_requests=True)
engine = AsyncLLM.from_engine_args(apc_engine_args)
try:
with pytest.raises(ValueError) as excinfo:
# Issue a request with prompt logprobs enabled, which should fail
await asyncio.create_task(
generate(engine,
"request-0",
TEXT_PROMPT,
output_kind,
10,
prompt_logprobs=5))
# Validate exception string is correct
assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
finally:
# Shut down engine
engine.shutdown()


@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args_and_prompt",
15 changes: 0 additions & 15 deletions tests/v1/engine/test_llm_engine.py
@@ -5,7 +5,6 @@

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams

MODEL = "facebook/opt-125m"
@@ -98,17 +97,3 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
raise AssertionError(
f"{len(completion_counts)} unique completions; expected"
f" {n}. Repeats: {repeats}")


def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
"""Test passes if LLMEngine raises an exception when it is configured
for automatic prefix caching and it receives a request with
prompt_logprobs enabled, which is incompatible."""
model: LLM = vllm_model_apc.model
with pytest.raises(ValueError) as excinfo:
model.generate(
"Hello, my name is",
SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

# Validate exception string is correct
assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
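
With the refusal removed, the combination the deleted test asserted against should now be accepted. A small usage sketch based on that removed test (assuming the V1 engine is in use, e.g. VLLM_USE_V1=1):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
outputs = llm.generate(
    "Hello, my name is",
    SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5),
)
# Instead of raising ValueError, per-prompt-token logprobs are returned.
print(outputs[0].prompt_logprobs)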
3 changes: 0 additions & 3 deletions tests/v1/engine/utils.py
@@ -30,9 +30,6 @@
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
PROMPT_LEN = 5

PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet "
"supported on VLLM V1.")

random.seed(42)

