Skip to content

Commit 49a48d0

Browse files
orozerydivakar-amd
authored andcommitted
[v1] Move block_hashes from KVCacheManager to Request.block_hashes (vllm-project#19728)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
1 parent c587b1b commit 49a48d0

19 files changed

+382
-336
lines changed

tests/v1/core/test_async_scheduler.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from vllm.v1.core.sched.output import SchedulerOutput
88
from vllm.v1.outputs import ModelRunnerOutput
99
from vllm.v1.request import RequestStatus
10+
from vllm.v1.utils import ConstantList
1011

1112
from .utils import create_requests, create_scheduler
1213

@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
140141
requests = create_requests(num_requests=5,
141142
num_tokens=num_prompt_tokens,
142143
max_tokens=3,
143-
same_prompt=True)
144+
same_prompt=True,
145+
block_size=BLOCK_SIZE)
144146
requests_copy = requests.copy()
145147

146148
# Two requests with the same prompt.
@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
188190
block_size=BLOCK_SIZE)
189191
requests = create_requests(num_requests=5,
190192
num_tokens=num_prompt_tokens,
191-
max_tokens=num_output_tokens)
193+
max_tokens=num_output_tokens,
194+
block_size=BLOCK_SIZE)
192195

193196
for req in requests:
194197
scheduler.add_request(req)
@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():
208211

209212
# Create next-turn requests whose prompts are the full output of the
210213
# previous turn.
211-
next_turn_requests = create_requests(
212-
num_requests=5,
213-
num_tokens=num_prompt_tokens + num_output_tokens,
214-
max_tokens=num_output_tokens,
215-
)
214+
next_turn_requests = create_requests(num_requests=5,
215+
num_tokens=num_prompt_tokens +
216+
num_output_tokens,
217+
max_tokens=num_output_tokens,
218+
block_size=BLOCK_SIZE)
216219
for i, req in enumerate(next_turn_requests):
217220
req.prompt_token_ids = (requests[i].prompt_token_ids +
218221
list(requests[i].output_token_ids))
222+
req._all_token_ids = req.prompt_token_ids.copy()
223+
req.all_token_ids = ConstantList(req._all_token_ids)
224+
req.block_hashes = []
225+
req.block_hashes = req.get_hash_new_full_blocks()
226+
219227
# Schedule the next-turn requests.
220228
for req in next_turn_requests:
221229
scheduler.add_request(req)

tests/v1/core/test_kv_cache_utils.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import importlib
4-
from typing import Optional
4+
from typing import Callable, Optional
55

66
import pytest
77
import torch
@@ -19,7 +19,7 @@
1919
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
2020
estimate_max_model_len, generate_block_hash_extra_keys,
2121
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
22-
hash_block_tokens, hash_request_tokens, init_none_hash,
22+
get_request_block_hasher, hash_block_tokens, init_none_hash,
2323
is_kv_cache_type_uniform, unify_kv_cache_configs)
2424
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
2525
KVCacheGroupSpec, KVCacheTensor,
@@ -33,6 +33,8 @@
3333
def make_request(
3434
request_id: str,
3535
prompt_token_ids: list[int],
36+
block_size: int = 3,
37+
hash_fn: Callable = hash,
3638
mm_positions: Optional[list[PlaceholderRange]] = None,
3739
mm_hashes: Optional[list[str]] = None,
3840
cache_salt: Optional[str] = None,
@@ -49,18 +51,17 @@ def make_request(
4951
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
5052
mm_kwargs = [mm_item] * len(mm_positions)
5153

52-
return Request(
53-
request_id=request_id,
54-
prompt_token_ids=prompt_token_ids,
55-
multi_modal_kwargs=mm_kwargs,
56-
multi_modal_hashes=mm_hashes,
57-
multi_modal_placeholders=mm_positions,
58-
sampling_params=SamplingParams(max_tokens=17),
59-
pooling_params=None,
60-
eos_token_id=100,
61-
lora_request=None,
62-
cache_salt=cache_salt,
63-
)
54+
return Request(request_id=request_id,
55+
prompt_token_ids=prompt_token_ids,
56+
multi_modal_kwargs=mm_kwargs,
57+
multi_modal_hashes=mm_hashes,
58+
multi_modal_placeholders=mm_positions,
59+
sampling_params=SamplingParams(max_tokens=17),
60+
pooling_params=None,
61+
eos_token_id=100,
62+
lora_request=None,
63+
cache_salt=cache_salt,
64+
block_hasher=get_request_block_hasher(block_size, hash_fn))
6465

6566

6667
def new_kv_cache_spec(block_size=16,
@@ -428,22 +429,22 @@ def test_hash_block_tokens(hash_fn):
428429

429430

430431
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
431-
def test_hash_request_tokens(hash_fn):
432+
def test_request_block_hasher(hash_fn):
432433
import vllm.v1.core.kv_cache_utils
433434
init_none_hash(hash_fn)
434435
request = make_request(
435436
request_id="0",
436437
prompt_token_ids=[_ for _ in range(6)],
438+
block_size=3,
439+
hash_fn=hash_fn,
437440
mm_positions=[
438441
PlaceholderRange(offset=0, length=3),
439442
PlaceholderRange(offset=3, length=3),
440443
],
441444
mm_hashes=["hash1", "hash2"],
442445
)
443446

444-
block_size = 3
445-
block_hashes = hash_request_tokens(hash_fn, block_size, request)
446-
447+
block_hashes = request.block_hashes
447448
assert len(block_hashes) == 2
448449
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
449450
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
464465
request1 = make_request(
465466
request_id="0",
466467
prompt_token_ids=[_ for _ in range(6)],
468+
block_size=3,
469+
hash_fn=hash_fn,
467470
mm_positions=[
468471
PlaceholderRange(offset=0, length=3),
469472
PlaceholderRange(offset=3, length=3),
@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
479482
],
480483
mm_hashes=["hash3", "hash2"],
481484
)
482-
block_size = 3
483-
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
484-
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
485+
block_hashes1 = request1.block_hashes
486+
block_hashes2 = request2.block_hashes
485487
assert block_hashes1[0] != block_hashes2[0]
486488
assert block_hashes1[1] != block_hashes2[1]
487489

@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
493495
request = make_request(
494496
request_id="0",
495497
prompt_token_ids=[_ for _ in range(6)],
498+
block_size=3,
499+
hash_fn=hash_fn,
496500
mm_positions=None,
497501
mm_hashes=None,
498502
)
499503

500-
block_size = 3
501-
block_hashes = hash_request_tokens(hash_fn, block_size, request)
504+
block_hashes = request.block_hashes
502505

503506
assert len(block_hashes) == 2
504507
assert block_hashes[0].token_ids == (0, 1, 2)
@@ -858,6 +861,7 @@ def test_allocate_with_lookahead():
858861
request = make_request(
859862
request_id="0",
860863
prompt_token_ids=[],
864+
block_size=block_size,
861865
mm_positions=None,
862866
mm_hashes=None,
863867
)

0 commit comments

Comments
 (0)