Skip to content

Commit bbaf05a

Browse files
sixiang-google and Lumosis
authored and committed
fix disagg upon block hasher change in vllm (#539)
Co-authored-by: sixiang-google <sixiang-google>
1 parent d1e3c63 commit bbaf05a

File tree

3 files changed

+41
-9
lines changed

3 files changed

+41
-9
lines changed

tests/core/test_core_tpu.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def mock_start():
8787
self.mock_vllm_config = MagicMock(spec=VllmConfig)
8888
self.mock_vllm_config.parallel_config = MagicMock(spec=ParallelConfig)
8989
self.mock_vllm_config.device_config = MagicMock()
90+
self.mock_vllm_config.cache_config = MagicMock()
91+
self.mock_vllm_config.cache_config.prefix_caching_hash_algo = "builtin"
92+
self.mock_vllm_config.cache_config.block_size = 5
9093
self.mock_vllm_config.__post_init__ = MagicMock()
9194

9295
def test_initialization(self):
@@ -128,6 +131,7 @@ def test_add_request(self):
128131
mock_request.use_structured_output = False
129132
mock_request.pooling_params = None
130133
mock_request.sampling_params.guided_decoding = None
134+
mock_request.block_hashes = []
131135

132136
mock_engine_request, _ = proc.preprocess_add_request(mock_request)
133137

@@ -173,6 +177,7 @@ def test_handle_client_request_add(self):
173177
mock_request.use_structured_output = False
174178
mock_request.pooling_params = None
175179
mock_request.sampling_params.guided_decoding = None
180+
mock_request.block_hashes = []
176181
mock_request = proc.preprocess_add_request(mock_request)
177182

178183
proc._handle_client_request(EngineCoreRequestType.ADD, mock_request)
@@ -330,7 +335,7 @@ def test_transfer_logic(self):
330335
orchestrator.live = True
331336

332337
# Mock kv cache map
333-
mock_kv_cache_map = {"test_req": [MagicMock()]}
338+
mock_kv_cache_map = {"test_req": ([MagicMock()], [])}
334339
orchestrator._transfer_backlogs[0].put(mock_kv_cache_map)
335340
orchestrator._transfer_backlogs[0].put(
336341
None) # Sentinel to stop the loop
@@ -354,7 +359,11 @@ def test_decode_logic(self):
354359
orchestrator.live = True
355360

356361
# Mock prefill output
357-
mock_prefill_output = {"req_id": "test_req", "cache": [MagicMock()]}
362+
mock_prefill_output = {
363+
"req_id": "test_req",
364+
"cache": [MagicMock()],
365+
"block_hashes": []
366+
}
358367
orchestrator._decode_backlogs[0].put(mock_prefill_output)
359368
orchestrator._decode_backlogs[0].put(None) # Sentinel to stop the loop
360369

tpu_commons/core/adapters.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def sampling_params(self):
138138
def lora_request(self):
139139
return self._vllm_request.lora_request
140140

141+
@property
142+
def block_hashes(self):
143+
return self._vllm_request.block_hashes
144+
141145
@status.setter
142146
def status(self, value: RequestStatus) -> None:
143147
self._vllm_request.status = value

tpu_commons/core/core_tpu.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import threading
99
import time
1010
import traceback
11-
from typing import Any, Callable, Optional, TypeVar, Union
11+
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
1212

1313
import jax
1414
# ======================================================================================
@@ -17,6 +17,9 @@
1717
from vllm.config import VllmConfig
1818
from vllm.logger import init_logger
1919
from vllm.tasks import POOLING_TASKS
20+
from vllm.utils import get_hash_fn_by_name
21+
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
22+
init_none_hash)
2023
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
2124
EngineCoreRequestType, UtilityOutput,
2225
UtilityResult)
@@ -173,7 +176,7 @@ def _prefill(self, idx: int):
173176
if scheduler_output.total_num_scheduled_tokens > 0:
174177
logger.debug(f"Prefill result: {model_output}")
175178

176-
kv_cache_map: dict[str, list[jax.Array]] = {}
179+
kv_cache_map: dict[str, Tuple(list[jax.Array], list[Any])] = {}
177180
for req_id, idx in model_output.req_id_to_index.items():
178181
if len(model_output.sampled_token_ids[idx]) > 0:
179182
request = self._requests[req_id]
@@ -185,7 +188,7 @@ def _prefill(self, idx: int):
185188
kv_cache_map[req_id] = (
186189
prefill_engine.model_executor.driver_worker.
187190
model_runner.get_kv_cache_for_block_ids(
188-
block_ids[0]))
191+
block_ids[0]), request.block_hashes)
189192
logger.debug(f"prefill done: for {req_id}")
190193
transfer_backlog.put(kv_cache_map, block=True)
191194

@@ -196,6 +199,9 @@ def _prefill(self, idx: int):
196199
for req_id, idx in model_output.req_id_to_index.items():
197200
if len(model_output.sampled_token_ids[idx]) > 0:
198201
request = self._requests[req_id].vllm_request
202+
logger.debug(
203+
f"request block_hashes at prefill: {request.block_hashes}"
204+
)
199205
logger.debug(
200206
f"request-{req_id}: tokens={request.all_token_ids} after prefill"
201207
)
@@ -229,7 +235,7 @@ def _transfer(self, idx: int):
229235
)
230236

231237
push_targets = []
232-
for req_id, kv_cache in kv_cachce_map.items():
238+
for req_id, (kv_cache, block_hashes) in kv_cachce_map.items():
233239
target_idx = -1
234240
cnt = 9999999
235241
for i, e in enumerate(self._decode_engines):
@@ -248,6 +254,7 @@ def _transfer(self, idx: int):
248254
prefill_output = {
249255
"cache": kv_cache,
250256
"req_id": req_id,
257+
"block_hashes": block_hashes,
251258
}
252259
push_targets.append((target_idx, prefill_output))
253260

@@ -313,6 +320,8 @@ def _decode(self, idx: int):
313320
)
314321
vllm_request.num_computed_tokens = prompt_tokens
315322
new_block_ids = kv_cache_manager.get_block_ids(req_id)
323+
logger.warning(
324+
f"decoding {req_id} new_block_ids {new_block_ids}")
316325
assert (len(new_block_ids[0]) == math.ceil(
317326
prompt_tokens / self._config.cache_config.block_size))
318327

@@ -321,6 +330,8 @@ def _decode(self, idx: int):
321330
vllm_request, kv_cache, new_block_ids)
322331

323332
vllm_request.status = RequestStatus.RUNNING
333+
block_hashes = prefill_output["block_hashes"]
334+
vllm_request.block_hashes = block_hashes
324335
decode_engine.scheduler.running.append(vllm_request)
325336
decode_engine.scheduler.requests[req_id] = vllm_request
326337

@@ -476,6 +487,17 @@ def executor_fail_callback():
476487
if addresses.coordinator_input is not None:
477488
logger.info("Waiting for READY message from DP Coordinator...")
478489

490+
if (self.vllm_config.cache_config.enable_prefix_caching
491+
or self.scheduler.get_kv_connector() is not None):
492+
493+
block_size = vllm_config.cache_config.block_size
494+
caching_hash_fn = get_hash_fn_by_name(
495+
vllm_config.cache_config.prefix_caching_hash_algo)
496+
init_none_hash(caching_hash_fn)
497+
498+
self.request_block_hasher = get_request_block_hasher(
499+
block_size, caching_hash_fn)
500+
479501
self._orchestrator = _DisaggOrchestrator(
480502
config=VllmConfigAdapter(vllm_config),
481503
output_queue=self.output_queue,
@@ -489,8 +511,6 @@ def executor_fail_callback():
489511
decode_slice_sizes=decode_slice_sizes,
490512
)
491513

492-
self.request_block_hasher = None
493-
494514
@staticmethod
495515
def _create_engine_cores(
496516
slice_sizes: tuple[int, ...],
@@ -532,7 +552,6 @@ def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
532552
def _handle_client_request(self, request_type: EngineCoreRequestType,
533553
request: Any) -> None:
534554
"""Dispatch request from client."""
535-
536555
if request_type == EngineCoreRequestType.ADD:
537556
req, request_wave = request
538557
self.add_request(req)

0 commit comments

Comments
 (0)