Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
Expand Down Expand Up @@ -449,6 +450,24 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1


def test_generate_block_hash_extra_keys_lora():
request = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
)

request.lora_request = LoRARequest(
lora_name="test_lora_adapter", lora_int_id=1, lora_path="/path/to/lora"
)

extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0)
assert extra_keys == ("test_lora_adapter",)

request.lora_request = None
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0)
assert extra_keys is None


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_block_tokens(hash_fn):
parent_block_hash = BlockHash(b"123")
Expand Down
12 changes: 6 additions & 6 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def need_extra_keys(request: Request) -> bool:
"""

# Multimodal requests need to include the MM hash.
# LoRA requests need to include the LoRA ID.
# LoRA requests need to include the LoRA name.
# Request with provided cache salt need to include the salt.
return (
bool(request.mm_features)
Expand Down Expand Up @@ -446,26 +446,26 @@ def _gen_mm_extra_hash_keys(
return extra_keys, curr_mm_idx


def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
"""Generate extra keys related to LoRA for block hash computation.

Args:
request: The request object.

Returns:
Return LoRA id of the request if it is a LoRA request. Return empty
Return LoRA name of the request if it is a LoRA request. Return empty
list otherwise.
"""
if not request.lora_request:
return []
return [request.lora_request.lora_int_id]
return [request.lora_request.lora_name]


def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int
) -> tuple[tuple[Any, ...] | None, int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
the multi-modal inputs and request specific metadata (e.g., LoRA name).

Args:
request: The request object.
Expand All @@ -480,7 +480,7 @@ def generate_block_hash_extra_keys(
mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
request, start_token_idx, end_token_idx, start_mm_idx
)
lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request)
lora_extra_keys: list[str] = _gen_lora_extra_hash_keys(request)
cache_salt_keys: list[str] = (
[request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else []
)
Expand Down