Skip to content

Commit a39a8a7

Browse files
sagiahracalbertoperdomo2
authored andcommitted
[Prefix Cache] Use LoRA name for consistent KV-cache block hashing (vllm-project#27211)
Signed-off-by: Sage Ahrac <sagiahrak@gmail.com> Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent 0e7677c commit a39a8a7

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

tests/v1/core/test_kv_cache_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import vllm.v1.core.kv_cache_utils as kv_cache_utils
1010
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
11+
from vllm.lora.request import LoRARequest
1112
from vllm.multimodal.inputs import (
1213
MultiModalFeatureSpec,
1314
MultiModalKwargsItem,
@@ -449,6 +450,24 @@ def test_generate_block_hash_extra_keys_cache_salt():
449450
assert next_mm_idx == 1
450451

451452

453+
def test_generate_block_hash_extra_keys_lora():
454+
request = make_request(
455+
request_id="0",
456+
prompt_token_ids=[_ for _ in range(6)],
457+
)
458+
459+
request.lora_request = LoRARequest(
460+
lora_name="test_lora_adapter", lora_int_id=1, lora_path="/path/to/lora"
461+
)
462+
463+
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0)
464+
assert extra_keys == ("test_lora_adapter",)
465+
466+
request.lora_request = None
467+
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0)
468+
assert extra_keys is None
469+
470+
452471
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
453472
def test_hash_block_tokens(hash_fn):
454473
parent_block_hash = BlockHash(b"123")

vllm/v1/core/kv_cache_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def need_extra_keys(request: Request) -> bool:
373373
"""
374374

375375
# Multimodal requests need to include the MM hash.
376-
# LoRA requests need to include the LoRA ID.
376+
# LoRA requests need to include the LoRA name.
377377
# Request with provided cache salt need to include the salt.
378378
return (
379379
bool(request.mm_features)
@@ -446,26 +446,26 @@ def _gen_mm_extra_hash_keys(
446446
return extra_keys, curr_mm_idx
447447

448448

449-
def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
449+
def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
450450
"""Generate extra keys related to LoRA for block hash computation.
451451
452452
Args:
453453
request: The request object.
454454
455455
Returns:
456-
Return LoRA id of the request if it is a LoRA request. Return empty
456+
Return LoRA name of the request if it is a LoRA request. Return empty
457457
list otherwise.
458458
"""
459459
if not request.lora_request:
460460
return []
461-
return [request.lora_request.lora_int_id]
461+
return [request.lora_request.lora_name]
462462

463463

464464
def generate_block_hash_extra_keys(
465465
request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int
466466
) -> tuple[tuple[Any, ...] | None, int]:
467467
"""Generate extra keys for the block hash. The extra keys can come from
468-
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
468+
the multi-modal inputs and request specific metadata (e.g., LoRA name).
469469
470470
Args:
471471
request: The request object.
@@ -480,7 +480,7 @@ def generate_block_hash_extra_keys(
480480
mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
481481
request, start_token_idx, end_token_idx, start_mm_idx
482482
)
483-
lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request)
483+
lora_extra_keys: list[str] = _gen_lora_extra_hash_keys(request)
484484
cache_salt_keys: list[str] = (
485485
[request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else []
486486
)

0 commit comments

Comments
 (0)