Commit bddb2f0
done
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
1 parent efbce85 commit bddb2f0

File tree

8 files changed: +297 -43 lines changed


tests/v1/core/test_prefix_caching.py

Lines changed: 86 additions & 2 deletions
@@ -2,16 +2,23 @@
 import pytest

 from vllm.inputs import token_inputs
+from vllm.multimodal.inputs import PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import cdiv
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
 from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens


-def make_request(request_id, prompt_token_ids):
+def make_request(request_id,
+                 prompt_token_ids,
+                 mm_positions=None,
+                 mm_hashes=None):
     return Request(
         request_id=request_id,
-        inputs=token_inputs(prompt_token_ids=prompt_token_ids),
+        inputs=token_inputs(prompt_token_ids=prompt_token_ids,
+                            multi_modal_placeholders={"image": mm_positions}
+                            if mm_positions else None,
+                            multi_modal_hashes=mm_hashes),
         sampling_params=SamplingParams(max_tokens=17),
         eos_token_id=100,
         arrival_time=0,
@@ -38,6 +45,7 @@ def test_prefill():
     all_token_ids = common_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     computed_blocks = manager.get_computed_blocks(req0)
+    assert len(req0.kv_block_hashes) == 3
     assert not computed_blocks
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
     assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@@ -61,6 +69,7 @@ def test_prefill():
     unique_token_ids = [3] * 5
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks = manager.get_computed_blocks(req1)
+    assert len(req1.kv_block_hashes) == 3
     assert [b.block_id for b in computed_blocks] == [0, 1, 2]
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
@@ -90,6 +99,7 @@ def test_prefill():
     unique_token_ids = [3] * 6
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_block = manager.get_computed_blocks(req2)
+    assert len(req2.kv_block_hashes) == 3
     assert [b.block_id for b in computed_block] == [0, 1, 2]
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
@@ -416,3 +426,77 @@ def test_cache_blocks():
     )
     assert len(manager.cached_block_hash_to_block) == 3
     assert blocks[0].block_hash is not None
+
+
+def test_mm_prefix_caching():
+    """
+    This tests that the multi-modal prefix caching is correct.
+    """
+    manager = KVCacheManager(
+        block_size=16,
+        num_gpu_blocks=10,
+        max_model_len=8192,
+        sliding_window=None,
+        enable_caching=True,
+        num_preallocate_tokens=16,
+    )
+
+    # Common prompt tokens (T is text tokens and P is image placeholder tokens)
+    # [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
+    common_token_ids = list(range(10)) + [-1] * 6
+    common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
+    common_token_ids += [-1] * 16
+
+    common_mm_positions = [
+        PlaceholderRange(offset=11, length=10),
+        PlaceholderRange(offset=30, length=18),
+    ]
+    common_mm_hashes = ["aaa", "bbb"]
+
+    # A unique image plus some text tokens.
+    unique_token_ids = [-1] * 7 + [100] * 4
+    all_token_ids = common_token_ids + unique_token_ids
+    mm_positions = common_mm_positions + [
+        PlaceholderRange(offset=48, length=7)
+    ]
+    mm_hashes = common_mm_hashes + ["ccc"]
+    req0 = make_request("0",
+                        all_token_ids,
+                        mm_positions=mm_positions,
+                        mm_hashes=mm_hashes)
+    computed_blocks = manager.get_computed_blocks(req0)
+
+    # Completed block should have hashes with extra keys.
+    assert not computed_blocks
+    assert len(req0.kv_block_hashes) == 3
+    assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), )
+    assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0))
+    assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), )
+
+    blocks = manager.allocate_slots(req0, 59, computed_blocks)
+    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
+    req0.num_computed_tokens = 59
+
+    # Append slots without allocating a new block.
+    for _ in range(5):
+        req0.append_output_token_ids(8)
+    new_blocks = manager.append_slots(req0, 5)
+    assert new_blocks is not None and len(new_blocks) == 0
+
+    # The just completed block should have hashes with extra keys.
+    assert len(req0.kv_block_hashes) == 4
+    assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), )
+
+    # Cache hit.
+    unique_token_ids = [-1] * 7 + [200] * 5
+    all_token_ids = common_token_ids + unique_token_ids
+    mm_positions = common_mm_positions + [
+        PlaceholderRange(offset=48, length=7)
+    ]
+    mm_hashes = common_mm_hashes + ["ccc"]
+    req1 = make_request("1",
+                        all_token_ids,
+                        mm_positions=mm_positions,
+                        mm_hashes=mm_hashes)
+    computed_blocks = manager.get_computed_blocks(req1)
+    assert len(computed_blocks) == 3
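For reference, the extra_keys asserted in the new test follow from how each image placeholder overlaps the 16-token blocks: each key appears to be (mm_hash, offset of the block's first overlapping token within that image). A minimal standalone sketch (not vLLM's implementation) that reproduces the expected tuples under that assumption:

# Sketch only: recompute the extra_keys asserted in test_mm_prefix_caching.
BLOCK_SIZE = 16
mm_inputs = [("aaa", 11, 10), ("bbb", 30, 18)]  # (hash, offset, length)

def extra_keys_for_block(blk_idx):
    blk_start, blk_end = blk_idx * BLOCK_SIZE, (blk_idx + 1) * BLOCK_SIZE
    keys = []
    for mm_hash, offset, length in mm_inputs:
        overlap_start = max(blk_start, offset)
        overlap_end = min(blk_end, offset + length)
        if overlap_start < overlap_end:
            # Record where inside the mm input this block starts overlapping.
            keys.append((mm_hash, overlap_start - offset))
    return tuple(keys)

assert extra_keys_for_block(0) == (("aaa", 0), )
assert extra_keys_for_block(1) == (("aaa", 5), ("bbb", 0))
assert extra_keys_for_block(2) == (("bbb", 2), )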

vllm/engine/arg_utils.py

Lines changed: 4 additions & 4 deletions
@@ -1027,11 +1027,11 @@ def create_engine_config(self,
         device_config = DeviceConfig(device=self.device)
         model_config = self.create_model_config()

-        if model_config.is_multimodal_model:
+        if model_config.is_multimodal_model and not envs.VLLM_USE_V1:
             if self.enable_prefix_caching:
-                logger.warning(
-                    "--enable-prefix-caching is currently not "
-                    "supported for multimodal models and has been disabled.")
+                logger.warning("--enable-prefix-caching is currently not "
+                               "supported for multimodal models in v0 and "
+                               "has been disabled.")
                 self.enable_prefix_caching = False

         cache_config = CacheConfig(
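With this change, prefix caching is only force-disabled for multimodal models on the v0 engine; under V1 it stays enabled. A hypothetical usage sketch (the model name and flag values are illustrative, not taken from this commit):

import os
os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine before importing vllm

from vllm import LLM

# enable_prefix_caching is no longer overridden to False for multimodal
# models when VLLM_USE_V1 is set.
llm = LLM(model="llava-hf/llava-1.5-7b-hf",  # illustrative multimodal model
          enable_prefix_caching=True)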

vllm/inputs/data.py

Lines changed: 20 additions & 0 deletions
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
     Placeholder ranges for the multi-modal data.
     """

+    multi_modal_hashes: NotRequired[List[Optional[str]]]
+    """
+    The hashes of the multi-modal data.
+    """
+
     mm_processor_kwargs: NotRequired[Dict[str, Any]]
     """
     Optional multi-modal processor kwargs to be forwarded to the
@@ -177,6 +182,7 @@ def token_inputs(
     prompt: Optional[str] = None,
     multi_modal_data: Optional["MultiModalDataDict"] = None,
     multi_modal_inputs: Optional["MultiModalKwargs"] = None,
+    multi_modal_hashes: Optional[List[Optional[str]]] = None,
     multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
     mm_processor_kwargs: Optional[Dict[str, Any]] = None,
 ) -> TokenInputs:
@@ -191,6 +197,8 @@ def token_inputs(
         inputs["multi_modal_data"] = multi_modal_data
     if multi_modal_inputs is not None:
         inputs["multi_modal_inputs"] = multi_modal_inputs
+    if multi_modal_hashes is not None:
+        inputs["multi_modal_hashes"] = multi_modal_hashes
     if multi_modal_placeholders is not None:
         inputs["multi_modal_placeholders"] = multi_modal_placeholders
     if mm_processor_kwargs is not None:
@@ -295,6 +303,18 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:

         assert_never(inputs)

+    @cached_property
+    def multi_modal_hashes(self) -> List[Optional[str]]:
+        inputs = self.inputs
+
+        if inputs["type"] == "token":
+            return inputs.get("multi_modal_hashes", [])
+
+        if inputs["type"] == "multimodal":
+            return inputs.get("mm_hashes", [])
+
+        assert_never(inputs)
+
     @cached_property
     def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
         inputs = self.inputs
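A small sketch of the new plumbing, mirroring the make_request() helper in the test above; the placeholder offsets and hash strings are illustrative only:

from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange

inputs = token_inputs(
    prompt_token_ids=[1] * 32,
    multi_modal_placeholders={"image": [PlaceholderRange(offset=8, length=16)]},
    multi_modal_hashes=["aaa"],  # one hash per multi-modal item
)

# The key is NotRequired, so readers fall back to an empty list when hashes
# were not supplied, just as the new multi_modal_hashes property does.
assert inputs.get("multi_modal_hashes", []) == ["aaa"]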

vllm/multimodal/inputs.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import UserDict, defaultdict
2-
from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple,
3-
TypedDict, TypeVar, Union, cast, final)
2+
from typing import (Any, Dict, List, Literal, Mapping, Optional, Sequence,
3+
Tuple, TypedDict, TypeVar, Union, cast, final)
44

55
import numpy as np
66
import torch
@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
215215
mm_kwargs: MultiModalKwargs
216216
"""Keyword arguments to be directly passed to the model after batching."""
217217

218+
mm_hashes: NotRequired[List[Optional[str]]]
219+
"""The hashes of the multi-modal data."""
220+
218221
mm_placeholders: MultiModalPlaceholderDict
219222
"""
220223
For each modality, information about the placeholder tokens in
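mm_hashes is optional and, judging from the test above, carries one entry per multi-modal item (None when no hash is available). A purely illustrative sketch of filling it with a content hash; vLLM's actual hashing scheme is not part of this diff:

import hashlib
from typing import List, Optional

def hash_mm_items(items: List[Optional[bytes]]) -> List[Optional[str]]:
    # One entry per item; None means "no hash available", matching the
    # List[Optional[str]] annotation above.
    return [hashlib.sha256(b).hexdigest() if b is not None else None
            for b in items]

print(hash_mm_items([b"raw image bytes", None]))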

vllm/v1/core/kv_cache_manager.py

Lines changed: 50 additions & 24 deletions
@@ -4,7 +4,9 @@
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
-                                         KVCacheBlock, hash_block_tokens,
+                                         KVCacheBlock,
+                                         generate_block_hash_extra_keys,
+                                         hash_block_tokens,
                                          hash_request_tokens)
 from vllm.v1.request import Request

@@ -83,10 +85,12 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:

         computed_blocks = []

-        # TODO(rickyx): potentially we could cache this so we don't have to
-        # recompute it every time.
-        block_hashes = hash_request_tokens(self.block_size,
-                                           request.all_token_ids)
+        # The block hashes for the request may already be computed
+        # if the request was preempted and resumed.
+        if not request.kv_block_hashes:
+            request.kv_block_hashes = hash_request_tokens(
+                self.block_size, request)
+        block_hashes = request.kv_block_hashes

         for block_hash in block_hashes:
             # block_hashes is a chain of block hashes. If a block hash is not
@@ -242,14 +246,16 @@ def allocate_slots(
         num_computed_tokens = len(computed_blocks) * self.block_size
         num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size

-        self._cache_full_blocks(
-            request=request,
-            blk_start_idx=len(computed_blocks),
-            # The new full blocks are the full blocks that are not computed.
-            full_blocks=self.req_to_blocks[request.request_id]
-            [len(computed_blocks):num_full_blocks],
-            prev_block=computed_blocks[-1] if computed_blocks else None,
-        )
+        new_full_blocks = self.req_to_blocks[
+            request.request_id][len(computed_blocks):num_full_blocks]
+        if new_full_blocks:
+            self._cache_full_blocks(
+                request=request,
+                blk_start_idx=len(computed_blocks),
+                # The new full blocks are the full blocks that are not computed.
+                full_blocks=new_full_blocks,
+                prev_block=computed_blocks[-1] if computed_blocks else None,
+            )

         return new_blocks

@@ -376,6 +382,8 @@ def _cache_full_blocks(
             full_blocks: The list of blocks to update hash metadata.
             prev_block: The previous block in the chain.
         """
+        num_cached_block_hashes = len(request.kv_block_hashes)
+
         # Update the new blocks with the block hashes through the chain.
         prev_block_hash_value = None
         if prev_block is not None:
@@ -387,17 +395,35 @@ def _cache_full_blocks(
         for i, blk in enumerate(full_blocks):
             blk_idx = blk_start_idx + i

-            block_tokens = request.all_token_ids[blk_idx *
-                                                 self.block_size:(blk_idx +
-                                                                  1) *
-                                                 self.block_size]
-            assert len(block_tokens) == self.block_size, (
-                f"Expected {self.block_size} tokens, got {len(block_tokens)} "
-                f"at {blk_idx}th block for request "
-                f"{request.request_id}({request})")
-
-            # Compute the hash of the current block.
-            block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
+            if blk_idx < num_cached_block_hashes:
+                # The block hash may already be computed in
+                # "get_computed_blocks" if the tokens are not generated by
+                # this request (either the prompt tokens or the previously
+                # generated tokens with preemption). In this case we simply
+                # reuse the block hash.
+                block_hash = request.kv_block_hashes[blk_idx]
+            else:
+                # Otherwise compute the block hash and cache it in the request
+                # in case it will be preempted in the future.
+                start_token_idx = blk_idx * self.block_size
+                end_token_idx = (blk_idx + 1) * self.block_size
+                block_tokens = request.all_token_ids[
+                    start_token_idx:end_token_idx]
+                assert len(block_tokens) == self.block_size, (
+                    f"Expected {self.block_size} tokens, got "
+                    f"{len(block_tokens)} at {blk_idx}th block for request "
+                    f"{request.request_id}({request})")
+
+                # Generate extra keys for multi-modal inputs. Note that since
+                # we reach to this branch only when the block is completed with
+                # generated tokens, we only need to consider the last mm input.
+                extra_keys, _ = generate_block_hash_extra_keys(
+                    request, start_token_idx, end_token_idx, -1)
+
+                # Compute the hash of the current block.
+                block_hash = hash_block_tokens(prev_block_hash_value,
+                                               block_tokens, extra_keys)
+                request.append_kv_block_hashes(block_hash)

             # Update and added the full block to the cache.
             blk.block_hash = block_hash
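The caching scheme hinges on two properties visible above: each block hash is chained on the previous block's hash, and completed blocks that cover multi-modal placeholders carry extra keys. A simplified, self-contained sketch of that idea (not the real hash_block_tokens/BlockHashType from kv_cache_utils):

from typing import Optional, Tuple

def sketch_block_hash(prev_hash: Optional[int],
                      block_tokens: Tuple[int, ...],
                      extra_keys: Optional[Tuple] = None) -> int:
    # Chaining prev_hash makes a block's hash depend on its entire prefix;
    # extra_keys (e.g. image content hashes) distinguish identical placeholder
    # tokens that stand for different images.
    return hash((prev_hash, block_tokens, extra_keys))

prev = None
for tokens, extra in [((0,) * 16, (("aaa", 0), )),
                      ((0,) * 16, (("aaa", 5), ("bbb", 0)))]:
    prev = sketch_block_hash(prev, tokens, extra)
    print(prev)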
