 import threading
 import time
 import traceback
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, Tuple, TypeVar, Union
 
 import jax
 # ======================================================================================
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.tasks import POOLING_TASKS
+from vllm.utils import get_hash_fn_by_name
+from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
+                                         init_none_hash)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput,
                             UtilityResult)
@@ -173,7 +176,7 @@ def _prefill(self, idx: int):
         if scheduler_output.total_num_scheduled_tokens > 0:
             logger.debug(f"Prefill result: {model_output}")
 
-            kv_cache_map: dict[str, list[jax.Array]] = {}
+            kv_cache_map: dict[str, Tuple[list[jax.Array], list[Any]]] = {}
             for req_id, idx in model_output.req_id_to_index.items():
                 if len(model_output.sampled_token_ids[idx]) > 0:
                     request = self._requests[req_id]
@@ -185,7 +188,7 @@ def _prefill(self, idx: int):
                     kv_cache_map[req_id] = (
                         prefill_engine.model_executor.driver_worker.
                         model_runner.get_kv_cache_for_block_ids(
-                            block_ids[0]))
+                            block_ids[0]), request.block_hashes)
                     logger.debug(f"prefill done: for {req_id}")
             transfer_backlog.put(kv_cache_map, block=True)
 
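With the two hunks above, each prefill step now enqueues a `(kv_cache, block_hashes)` pair per request instead of the bare KV arrays, so the hashes computed during prefill travel with the cache to the transfer thread. A minimal sketch of the new payload shape, using plain Python stand-ins for the engine types (all names below are illustrative, not vLLM's actual classes):

```python
import queue
from typing import Any

# Stand-ins only: the real map pairs list[jax.Array] KV slabs with vLLM block hashes.
transfer_backlog: "queue.Queue[dict[str, tuple[Any, Any]]]" = queue.Queue()

# One entry per request that sampled a token during this prefill step.
kv_cache_map = {
    "req-0": ([object()], [b"h0", b"h1"]),  # (per-layer KV arrays, block hashes)
}
transfer_backlog.put(kv_cache_map, block=True)

# The transfer thread unpacks the pair in its loop header, mirroring the
# `for req_id, (kv_cache, block_hashes) in ...` change below.
for req_id, (kv_cache, block_hashes) in transfer_backlog.get(block=True).items():
    assert len(block_hashes) == 2
```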
@@ -196,6 +199,9 @@ def _prefill(self, idx: int):
             for req_id, idx in model_output.req_id_to_index.items():
                 if len(model_output.sampled_token_ids[idx]) > 0:
                     request = self._requests[req_id].vllm_request
+                    logger.debug(
+                        f"request block_hashes at prefill: {request.block_hashes}"
+                    )
                     logger.debug(
                         f"request-{req_id}: tokens={request.all_token_ids} after prefill"
                     )
@@ -229,7 +235,7 @@ def _transfer(self, idx: int):
             )
 
             push_targets = []
-            for req_id, kv_cache in kv_cachce_map.items():
+            for req_id, (kv_cache, block_hashes) in kv_cachce_map.items():
                 target_idx = -1
                 cnt = 9999999
                 for i, e in enumerate(self._decode_engines):
@@ -248,6 +254,7 @@ def _transfer(self, idx: int):
                 prefill_output = {
                     "cache": kv_cache,
                     "req_id": req_id,
+                    "block_hashes": block_hashes,
                 }
                 push_targets.append((target_idx, prefill_output))
 
@@ -313,6 +320,8 @@ def _decode(self, idx: int):
                     )
                     vllm_request.num_computed_tokens = prompt_tokens
                     new_block_ids = kv_cache_manager.get_block_ids(req_id)
+                    logger.warning(
+                        f"decoding {req_id} new_block_ids {new_block_ids}")
                     assert (len(new_block_ids[0]) == math.ceil(
                         prompt_tokens / self._config.cache_config.block_size))
 
@@ -321,6 +330,8 @@ def _decode(self, idx: int):
                         vllm_request, kv_cache, new_block_ids)
 
                     vllm_request.status = RequestStatus.RUNNING
+                    block_hashes = prefill_output["block_hashes"]
+                    vllm_request.block_hashes = block_hashes
                     decode_engine.scheduler.running.append(vllm_request)
                     decode_engine.scheduler.requests[req_id] = vllm_request
 
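On the decode side, the hashes are reattached to the restored request before it is handed to the scheduler, so the decode engine's prefix cache can recognize the transferred blocks. A condensed sketch of that consumption step (a hypothetical stand-in type; the real code operates on vLLM `Request` objects):

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class FakeRequest:
    # Stand-in for vllm.v1.request.Request; only the fields used here.
    req_id: str
    block_hashes: list[Any] = field(default_factory=list)

# Payload shape produced by the transfer thread above.
prefill_output = {"cache": [object()], "req_id": "req-0",
                  "block_hashes": [b"h0", b"h1"]}

vllm_request = FakeRequest(prefill_output["req_id"])
# Reattach the prefill-computed hashes before scheduling the request to decode.
vllm_request.block_hashes = prefill_output["block_hashes"]
assert vllm_request.block_hashes == [b"h0", b"h1"]
```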
@@ -476,6 +487,17 @@ def executor_fail_callback():
         if addresses.coordinator_input is not None:
             logger.info("Waiting for READY message from DP Coordinator...")
 
+        if (self.vllm_config.cache_config.enable_prefix_caching
+                or self.scheduler.get_kv_connector() is not None):
+
+            block_size = vllm_config.cache_config.block_size
+            caching_hash_fn = get_hash_fn_by_name(
+                vllm_config.cache_config.prefix_caching_hash_algo)
+            init_none_hash(caching_hash_fn)
+
+            self.request_block_hasher = get_request_block_hasher(
+                block_size, caching_hash_fn)
+
         self._orchestrator = _DisaggOrchestrator(
             config=VllmConfigAdapter(vllm_config),
             output_queue=self.output_queue,
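The block above mirrors upstream vLLM v1's prefix-cache hashing setup: pick a hash function by name, seed the sentinel hash for the empty prefix via `init_none_hash`, and build a per-request block hasher. For intuition, here is a self-contained sketch of the chained block-hash idea this enables; everything below is illustrative, not vLLM's actual implementation:

```python
import hashlib
from typing import Optional

def hash_block(parent_hash: Optional[bytes], token_ids: tuple[int, ...]) -> bytes:
    # A block's hash folds in its parent's hash, so equal hashes imply an
    # identical token prefix up to and including this block.
    h = hashlib.sha256()
    h.update(parent_hash or b"none")  # stand-in for vLLM's NONE_HASH sentinel
    h.update(repr(token_ids).encode())
    return h.digest()

def request_block_hashes(token_ids: list[int], block_size: int) -> list[bytes]:
    hashes: list[bytes] = []
    parent: Optional[bytes] = None
    # Only fully filled blocks are hashed; a trailing partial block is skipped.
    full = len(token_ids) - len(token_ids) % block_size
    for start in range(0, full, block_size):
        parent = hash_block(parent, tuple(token_ids[start:start + block_size]))
        hashes.append(parent)
    return hashes

# Two requests sharing a 16-token prefix share their first block hash only.
a = request_block_hashes(list(range(40)), block_size=16)
b = request_block_hashes(list(range(16)) + [99] * 24, block_size=16)
assert a[0] == b[0] and a[1] != b[1]
```

Because the hashes are a pure function of token ids, transferring them alongside the KV cache (rather than recomputing on the decode engine) is purely an optimization; the next hunk removes the old `request_block_hasher = None` default now that the field is set conditionally above.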
@@ -489,8 +511,6 @@ def executor_fail_callback():
             decode_slice_sizes=decode_slice_sizes,
         )
 
-        self.request_block_hasher = None
-
     @staticmethod
     def _create_engine_cores(
         slice_sizes: tuple[int, ...],
@@ -532,7 +552,6 @@ def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
         """Dispatch request from client."""
-
         if request_type == EngineCoreRequestType.ADD:
             req, request_wave = request
             self.add_request(req)