Commit 77073c7

[Core] Prevent side-channel attacks via cache salting (vllm-project#17045)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
1 parent a7d5b01 commit 77073c7

18 files changed, +324 -122 lines

docs/source/design/v1/prefix_caching.md

Lines changed: 19 additions & 1 deletion
@@ -16,7 +16,7 @@ In the example above, the KV cache in the first block can be uniquely identified

 * Parent hash value: The hash value of the parent hash block.
 * Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
-* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below).
+* Extra hashes: Other values required to make this block unique, such as LoRA IDs, multi-modality input hashes (see the example below), and cache salts to isolate caches in multi-tenant environments.

 > **Note 1:** We only cache full blocks.

@@ -76,6 +76,24 @@ Block 3

 In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow.

+**Cache Isolation for Security**
+To improve privacy in shared environments, vLLM supports isolating prefix cache reuse through optional per-request salting. When a `cache_salt` is included in the request, it is injected into the hash of the first block, so only requests carrying the same salt can reuse cached KV blocks. This prevents timing-based attacks in which an adversary could infer cached content by observing latency differences, and it adds this protection without compromising performance.
+
+```json
+{
+  "messages": [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "Here is a document with details about the world series: ..."},
+    {"role": "user", "content": "Who won the world series in 2020?"}
+  ],
+  "cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ=="
+}
+```
+
+With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others.
+
+> **Note:** Cache isolation is not supported in engine V0.
+
 ## Data Structure

 The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified):
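
For completeness, here is a client-side usage sketch (not part of this commit). It assumes a vLLM OpenAI-compatible server on `http://localhost:8000` and a placeholder model name; since `cache_salt` is a vLLM-specific extension, the `openai` Python client passes it through `extra_body`:

```python
# Hypothetical client-side sketch; server URL and model name are placeholders.
import secrets

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# One salt per tenant or trust group; ~32 random bytes keeps it unpredictable.
tenant_salt = secrets.token_urlsafe(32)

response = client.chat.completions.create(
    model="my-model",  # placeholder; use the model the server was started with
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
    ],
    # cache_salt is not part of the OpenAI schema, so it goes via extra_body.
    extra_body={"cache_salt": tenant_salt},
)
print(response.choices[0].message.content)
```

Requests that share `tenant_salt` can reuse each other's cached prefix blocks; requests with a different salt (or no salt) cannot.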

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 40 additions & 0 deletions
@@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config():

     assert mock_engine.generate.call_args.args[1].temperature == 0.0
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
+
+
+def test_serving_chat_did_set_correct_cache_salt():
+    mock_model_config = MockModelConfig()
+
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    # Initialize the serving chat
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     mock_model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    # Test cache_salt
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+    )
+
+    # By default cache_salt in the engine prompt is not set
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+    assert "cache_salt" not in mock_engine.generate.call_args.args[0]
+
+    # Test with a specific cache_salt
+    req.cache_salt = "test_salt"
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"

tests/tokenization/test_detokenize.py

Lines changed: 10 additions & 2 deletions
@@ -60,8 +60,16 @@ def _run_incremental_decode(tokenizer,
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest("", prompt_token_ids, None, None, None, params,
-                                None, 0.0, None)
+    request = EngineCoreRequest("",
+                                prompt_token_ids,
+                                None,
+                                None,
+                                None,
+                                params,
+                                None,
+                                0.0,
+                                None,
+                                cache_salt=None)

     if fast is None:
         detokenizer = IncrementalDetokenizer.from_new_request(

tests/v1/core/test_kv_cache_utils.py

Lines changed: 42 additions & 1 deletion
@@ -29,7 +29,8 @@
 def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
-                 mm_hashes=None):
+                 mm_hashes=None,
+                 cache_salt=None):
     if mm_positions is None:
         multi_modal_inputs = None
     else:

@@ -45,6 +46,7 @@ def make_request(request_id,
         eos_token_id=100,
         arrival_time=0,
         lora_request=None,
+        cache_salt=cache_salt,
     )


@@ -213,6 +215,45 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
     assert next_mm_idx == 0


+def test_generate_block_hash_extra_keys_cache_salt():
+    request = make_request(
+        request_id=0,
+        prompt_token_ids=[_ for _ in range(6)],
+        mm_positions=None,
+        mm_hashes=None,
+        cache_salt="salt",
+    )
+
+    # salt is added for the first token
+    extra_keys, _ = generate_block_hash_extra_keys(request, 0, 1, 0)
+    assert extra_keys == ('salt', )
+    extra_keys, _ = generate_block_hash_extra_keys(request, 0, 10, 0)
+    assert extra_keys == ('salt', )
+
+    # no salt added for other tokens
+    extra_keys, _ = generate_block_hash_extra_keys(request, 1, 2, 0)
+    assert extra_keys is None
+    extra_keys, _ = generate_block_hash_extra_keys(request, 6, 10, 0)
+    assert extra_keys is None
+
+    # works together with other extra keys
+    request_mm = make_request(
+        request_id=0,
+        prompt_token_ids=[_ for _ in range(20)],
+        mm_positions=[
+            PlaceholderRange(offset=0, length=5),
+        ],
+        mm_hashes=["hash1"],
+        cache_salt="salt",
+    )
+
+    # Both the mm hash and the salt end up in the first block's extra keys
+    extra_keys, next_mm_idx = generate_block_hash_extra_keys(
+        request_mm, 0, 5, 0)
+    assert extra_keys == ("hash1", "salt")
+    assert next_mm_idx == 1
+
+
 @pytest.mark.parametrize("hash_fn", [sha256, hash])
 def test_hash_block_tokens(hash_fn):
     parent_block_hash = 123
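
The behavior exercised above, where the salt appears only in the first block's extra keys, is enough to separate whole cache chains, because every block hash also folds in its parent's hash. The following standalone sketch (simplified, not vLLM's internal code) illustrates the idea:

```python
# Simplified illustration of salted, chained block hashing (not vLLM's code).
import hashlib
from typing import Optional


def block_hash(parent: Optional[str], tokens: tuple, extra_keys: tuple = ()) -> str:
    # Hash of (parent hash, block tokens, extra keys), mirroring the design doc.
    payload = repr((parent, tokens, extra_keys)).encode()
    return hashlib.sha256(payload).hexdigest()


def hash_blocks(token_ids: list, block_size: int,
                cache_salt: Optional[str] = None) -> list:
    hashes, parent = [], None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        # The salt is an extra key for the first block only; later blocks
        # inherit it transitively through the parent hash.
        extra = (cache_salt, ) if (start == 0 and cache_salt) else ()
        parent = block_hash(parent, block, extra)
        hashes.append(parent)
    return hashes


tokens = list(range(32))  # two full blocks of 16 tokens
assert hash_blocks(tokens, 16, "salt1") == hash_blocks(tokens, 16, "salt1")
assert hash_blocks(tokens, 16, "salt1") != hash_blocks(tokens, 16, "salt2")
```

Two requests with identical tokens but different salts therefore never share a block hash, which is what `test_cache_key_salting` in the next file checks at the KV cache manager level.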

tests/v1/core/test_prefix_caching.py

Lines changed: 63 additions & 1 deletion
@@ -21,7 +21,8 @@ def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
                  mm_hashes=None,
-                 prompt_logprobs: Optional[int] = None):
+                 prompt_logprobs: Optional[int] = None,
+                 cache_salt: Optional[str] = None):
     if mm_positions is None:
         multi_modal_inputs = None
     else:

@@ -38,6 +39,7 @@ def make_request(request_id,
         eos_token_id=100,
         arrival_time=0,
         lora_request=None,
+        cache_salt=cache_salt,
     )


@@ -603,6 +605,66 @@ def test_mm_prefix_caching():
     assert num_computed_tokens == 3 * 16


+def test_cache_key_salting():
+    """
+    This tests that cache salts are applied during hashing and the cache
+    is separated as expected.
+    """
+    block_size = 16
+    manager = KVCacheManager(
+        make_kv_cache_config(block_size, 11),
+        max_model_len=8192,
+        enable_caching=True,
+    )
+
+    # 3 complete blocks and an incomplete block with 11 tokens.
+    common_token_ids = [i for i in range(3) for _ in range(block_size)]
+    token_ids = common_token_ids + [3] * 11
+    req0 = make_request("0", token_ids, cache_salt="salt1")
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+
+    # Completed block should have hashes with extra keys.
+    assert not computed_blocks
+    assert num_computed_tokens == 0
+    block_hashes = manager.req_to_block_hashes[req0.request_id]
+    assert len(block_hashes) == 3
+    assert block_hashes[0].extra_keys == ("salt1", )
+    assert block_hashes[1].extra_keys is None
+    assert block_hashes[2].extra_keys is None
+
+    blocks = manager.allocate_slots(req0, 59, computed_blocks)
+    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    req0.num_computed_tokens = 59
+
+    # Append slots without allocating a new block.
+    for _ in range(5):
+        req0.append_output_token_ids(8)
+    new_blocks = manager.allocate_slots(req0, 5)
+    assert new_blocks is not None and len(new_blocks) == 0
+
+    # Now one more block that should not have extra keys.
+    assert len(block_hashes) == 4
+    assert block_hashes[3].extra_keys is None
+
+    # Test cache hit with a new request that has the same salt.
+    token_ids = common_token_ids + [4] * 11
+    req1 = make_request("1", token_ids, cache_salt="salt1")
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+    # Should match only a prefix of 3 blocks.
+    assert len(computed_blocks) == 3
+    assert num_computed_tokens == 3 * block_size
+
+    # Test cache miss with same content but different salt.
+    token_ids = common_token_ids + [4] * 11
+    req2 = make_request("2", token_ids, cache_salt="salt2")
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
+    assert len(computed_blocks) == 0
+    assert num_computed_tokens == 0
+    block_hashes = manager.req_to_block_hashes[req2.request_id]
+    assert len(block_hashes) == 3
+    assert block_hashes[0].extra_keys == ("salt2", )
+
+
 def test_prefill_not_enough_free_blocks_with_computed_blocks():
     """
     This is a unit test that tests the correctness of the allocate_slots

tests/v1/engine/test_engine_core.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def make_request() -> EngineCoreRequest:
         eos_token_id=None,
         arrival_time=time.time(),
         lora_request=None,
+        cache_salt=None,
     )

tests/v1/engine/test_engine_core_client.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
         eos_token_id=None,
         arrival_time=time.time(),
         lora_request=None,
+        cache_salt=None,
     )

tests/v1/engine/test_output_processor.py

Lines changed: 6 additions & 1 deletion
@@ -57,6 +57,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
         mm_placeholders=None,
         eos_token_id=None,
         lora_request=None,
+        cache_salt=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,

@@ -403,6 +404,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
         mm_placeholders=None,
         eos_token_id=None,
         lora_request=None,
+        cache_salt=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,

@@ -503,7 +505,7 @@ def test_stop_token(include_stop_str_in_output: bool,
       reason should be "stop" (i.e. first control token causes stop
       and is represented in output text)

-    * else, the detokenized string should be 
+    * else, the detokenized string should be
      <token><token>...<token> and the finish reason should be "stop"
      (i.e. first control token causes stop but is not represented
      in output text.)

@@ -565,6 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
         mm_placeholders=None,
         eos_token_id=eos_token_id,
         lora_request=None,
+        cache_salt=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,

@@ -661,6 +664,7 @@ def test_stop_string(include_stop_str_in_output: bool,
         mm_placeholders=None,
         eos_token_id=None,
         lora_request=None,
+        cache_salt=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,

@@ -774,6 +778,7 @@ def test_iteration_stats(dummy_test_vectors):
             mm_placeholders=None,
             eos_token_id=None,
             lora_request=None,
+            cache_salt=None,
             sampling_params=SamplingParams(),
         ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

vllm/entrypoints/openai/protocol.py

Lines changed: 28 additions & 4 deletions
@@ -14,6 +14,7 @@
                       ValidationInfo, field_validator, model_validator)
 from typing_extensions import TypeAlias

+from vllm import envs
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams

@@ -408,6 +409,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "If specified with 'logprobs', tokens are represented "
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))
+    cache_salt: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the prefix cache will be salted with the provided "
+            "string to prevent an attacker from guessing prompts in "
+            "multi-user environments. The salt should be random, protected "
+            "from access by 3rd parties, and long enough to be "
+            "unpredictable (e.g., 43 characters base64-encoded, "
+            "corresponding to 256 bits). Not supported by vLLM engine V0."))

 # doc: end-chat-completion-extra-params

@@ -726,6 +736,20 @@ def check_generation_prompt(cls, data):
                 "`add_generation_prompt` to True.")
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_cache_salt_support(cls, data):
+        if data.get("cache_salt") is not None:
+            if not envs.VLLM_USE_V1:
+                raise ValueError(
+                    "Parameter 'cache_salt' is not supported with "
+                    "this instance of vLLM, which uses engine V0.")
+            if not isinstance(data["cache_salt"],
+                              str) or not data["cache_salt"]:
+                raise ValueError("Parameter 'cache_salt' must be a "
+                                 "non-empty string if provided.")
+        return data
+

 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation

@@ -1622,9 +1646,9 @@ class TranscriptionRequest(OpenAIBaseModel):

     # doc: begin-transcription-extra-params
     stream: Optional[bool] = False
-    """Custom field not present in the original OpenAI definition. When set, 
+    """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
-    Completion endpoint. 
+    Completion endpoint.
     """
     # Flattened stream option to simplify form data.
     stream_include_usage: Optional[bool] = False

@@ -1642,15 +1666,15 @@ class TranscriptionRequest(OpenAIBaseModel):
     """

     top_p: Optional[float] = None
-    """Enables nucleus (top-p) sampling, where tokens are selected from the 
+    """Enables nucleus (top-p) sampling, where tokens are selected from the
     smallest possible set whose cumulative probability exceeds `p`.
     """

     top_k: Optional[int] = None
     """Limits sampling to the `k` most probable tokens at each step."""

     min_p: Optional[float] = None
-    """Filters out tokens with a probability lower than `min_p`, ensuring a 
+    """Filters out tokens with a probability lower than `min_p`, ensuring a
     minimum likelihood threshold during sampling.
     """
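
Here is a minimal sketch of how the new request validation behaves (it assumes vLLM is installed with the V1 engine enabled, i.e. `VLLM_USE_V1=1`; on a V0 deployment the same constructor raises instead of accepting the salt):

```python
# Validation sketch; assumes the V1 engine is enabled (VLLM_USE_V1=1).
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

# A well-formed salt is accepted and carried on the request object.
req = ChatCompletionRequest(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "what is 1+1?"}],
    cache_salt="Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
)
assert req.cache_salt is not None

# An empty (or non-string) salt is rejected by check_cache_salt_support.
try:
    ChatCompletionRequest(
        model="my-model",
        messages=[{"role": "user", "content": "hi"}],
        cache_salt="",
    )
except ValueError as err:  # pydantic surfaces this as a ValidationError
    print(err)
```

Because the validator runs in "before" mode, a request carrying `cache_salt` on a V0 deployment fails fast with a clear error rather than having the salt silently ignored.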

vllm/entrypoints/openai/serving_engine.py

Lines changed: 3 additions & 0 deletions
@@ -470,6 +470,9 @@ async def _preprocess_chat(
         if request.mm_processor_kwargs is not None:
             engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

+        if hasattr(request, "cache_salt") and request.cache_salt is not None:
+            engine_prompt["cache_salt"] = request.cache_salt
+
         return conversation, [request_prompt], [engine_prompt]

     def _log_inputs(

0 commit comments
