@@ -171,22 +171,10 @@ def free(self, request_id: str) -> None:
171171 self .num_cached_block .pop (request_id , None )
172172
@abstractmethod
def get_num_common_prefix_blocks(
    self,
    running_request_id: str,
    running_request_ids: list[str],
    transfering_request_ids: list[str],
) -> int:
    """
    Get the number of common prefix blocks for all running requests.

    This base version performs no computation; it only raises, so every
    concrete manager has to provide its own implementation.

    Args:
        running_request_id: The ID of one running request. Presumably the
            request whose blocks are used as the reference -- confirm
            against the concrete managers.
        running_request_ids: The IDs of the running requests.
        transfering_request_ids: Presumably the IDs of requests that are
            currently being transferred -- confirm against the caller.
            The spelling of the parameter name is kept as-is because it
            is part of the call interface.

    Returns:
        The number of common prefix blocks.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
191179
192180 @classmethod
@@ -289,15 +277,34 @@ def remove_skipped_blocks(self, request_id: str,
289277 # No need to remove blocks for full attention.
290278 pass
291279
def get_num_common_prefix_blocks(
        self, running_request_id: str, running_request_ids: list[str],
        transfering_request_ids: list[str]) -> int:
    """
    Get the number of leading blocks shared by all running and transferring
    requests.

    The block list of ``running_request_id`` is used as the reference. A
    block at position ``i`` counts as common only when every request in
    ``running_request_ids + transfering_request_ids`` has a block at
    position ``i`` with the same ``block_id``. Counting stops at the first
    position where this does not hold.

    Args:
        running_request_id: The ID of the request whose blocks are the
            reference.
        running_request_ids: The IDs of the running requests.
        transfering_request_ids: The IDs of the transferring requests.

    Returns:
        The length of the shared block prefix. 0 is returned when the
        reference request has no entry in ``self.req_to_blocks``, when no
        request IDs are supplied, or when any supplied request has no
        entry in ``self.req_to_blocks``.
    """
    req_to_blocks = self.req_to_blocks
    if running_request_id not in req_to_blocks:
        return 0

    all_request_ids = running_request_ids + transfering_request_ids
    if not all_request_ids:
        # Nothing to share with: report 0 instead of vacuously counting
        # every block of the reference request as common.
        return 0

    # A request without a block list cannot share any block, so the common
    # prefix is empty. Membership is tested with ``in`` (not indexing) so
    # that no entry is created if req_to_blocks is a defaultdict.
    if any(req_id not in req_to_blocks for req_id in all_request_ids):
        return 0

    reference_blocks = req_to_blocks[running_request_id]
    block_lists = [req_to_blocks[req_id] for req_id in all_request_ids]

    num_common_blocks = 0
    for i, ref_block in enumerate(reference_blocks):
        ref_block_id = ref_block.block_id
        # all() stops at the first request that diverges at position i.
        if not all(
                i < len(blocks) and blocks[i].block_id == ref_block_id
                for blocks in block_lists):
            break
        num_common_blocks += 1

    return num_common_blocks
302309
303310
@@ -390,8 +397,12 @@ def remove_skipped_blocks(self, request_id: str,
390397 blocks [i ] = self ._null_block
391398 self .block_pool .free_blocks (removed_blocks )
392399
393- def get_num_common_prefix_blocks (self , request_id : str ,
394- num_running_requests : int ) -> int :
400+ def get_num_common_prefix_blocks (
401+ self ,
402+ running_request_id : str ,
403+ running_request_ids : list [str ],
404+ transfering_request_ids : list [str ],
405+ ) -> int :
395406 """
396407 NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
397408 So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -518,8 +529,12 @@ def remove_skipped_blocks(self, request_id: str,
518529 blocks [i ] = self ._null_block
519530 self .block_pool .free_blocks (removed_blocks )
520531
521- def get_num_common_prefix_blocks (self , request_id : str ,
522- num_running_requests : int ) -> int :
532+ def get_num_common_prefix_blocks (
533+ self ,
534+ running_request_id : str ,
535+ running_request_ids : list [str ],
536+ transfering_request_ids : list [str ],
537+ ) -> int :
523538 """
524539 cascade attention is not supported by chunked local attention.
525540 """
@@ -555,8 +570,12 @@ def remove_skipped_blocks(self, request_id: str,
555570 # remove blocks.
556571 pass
557572
def get_num_common_prefix_blocks(
        self, running_request_id: str, running_request_ids: list[str],
        transfering_request_ids: list[str]) -> int:
    """
    Get the number of common prefix blocks, which is always 0 for this
    manager.

    None of the arguments is read.

    Args:
        running_request_id: Unused.
        running_request_ids: Unused.
        transfering_request_ids: Unused.

    Returns:
        Always 0.
    """
    return 0
561580
562581 def get_num_blocks_to_allocate (
@@ -618,8 +637,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
618637 # requests, so this method is not relevant.
619638 raise ValueError ("Should not be called as prefix caching is disabled." )
620639
def get_num_common_prefix_blocks(
    self,
    running_request_id: str,
    running_request_ids: list[str],
    transfering_request_ids: list[str],
) -> int:
    """
    Get the number of common prefix blocks, which is always 0 for
    cross-attention.

    None of the arguments is read.

    Returns:
        Always 0.
    """
    # The encoder states held in cross-attention blocks belong to a single
    # request; they are never shared across requests, so no common prefix
    # exists.
    return 0
0 commit comments