@@ -168,22 +168,10 @@ def free(self, request_id: str) -> None:
168168 self .num_cached_block .pop (request_id , None )
169169
170170 @abstractmethod
171- def get_num_common_prefix_blocks (self , request_id : str ,
172- num_running_requests : int ) -> int :
173- """
174- Get the number of common prefix blocks for all requests in the RUNNING
175- state.
176-
177- Args:
178- request_id: The request ID.
179- num_running_requests: The total number of requests in the RUNNING
180- state.
181-
182- Returns:
183- The number of common prefix blocks for all requests in the RUNNING
184- state.
185- """
186-
171+ def get_num_common_prefix_blocks (
172+ self , running_request_id : str , num_running_requests : int ,
173+ transfering_request_ids : list [str ]) -> int :
174+ """Get the number of common prefix blocks for all running requests."""
187175 raise NotImplementedError
188176
189177 @classmethod
@@ -281,15 +269,30 @@ def remove_skipped_blocks(self, request_id: str,
281269 # No need to remove blocks for full attention.
282270 pass
283271
284- def get_num_common_prefix_blocks (self , request_id : str ,
285- num_running_requests : int ) -> int :
286- blocks = self .req_to_blocks [request_id ]
272+ def get_num_common_prefix_blocks (
273+ self , running_request_id : str , num_running_requests : int ,
274+ transfering_request_ids : list [str ]) -> int :
275+ """Count leading blocks shared by all running requests, discounting transferring refs."""
276+ if running_request_id not in self .req_to_blocks :
277+ return 0
278+
279+ running_blocks = self .req_to_blocks [running_request_id ]
280+ transfering_blocks = [
281+ self .req_to_blocks [req_id ] for req_id in transfering_request_ids
282+ if req_id in self .req_to_blocks
283+ ]
284+
287285 num_common_blocks = 0
288- for block in blocks :
289- if block .ref_cnt == num_running_requests :
286+ for i , block in enumerate (running_blocks ):
287+ num_transfering_blocks = sum (
288+ 1 for blocks in transfering_blocks if i < len (blocks )
289+ and blocks [i ].block_id == running_blocks [i ].block_id )
290+
291+ if block .ref_cnt == num_running_requests + num_transfering_blocks :
290292 num_common_blocks += 1
291293 else :
292294 break
295+
293296 return num_common_blocks
294297
295298
@@ -380,8 +383,12 @@ def remove_skipped_blocks(self, request_id: str,
380383 blocks [i ] = self ._null_block
381384 self .block_pool .free_blocks (removed_blocks )
382385
383- def get_num_common_prefix_blocks (self , request_id : str ,
384- num_running_requests : int ) -> int :
386+ def get_num_common_prefix_blocks (
387+ self ,
388+ running_request_id : str ,
389+ num_running_requests : int ,
390+ transfering_request_ids : list [str ],
391+ ) -> int :
385392 """
386393 NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
387394 So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -506,8 +513,12 @@ def remove_skipped_blocks(self, request_id: str,
506513 blocks [i ] = self ._null_block
507514 self .block_pool .free_blocks (removed_blocks )
508515
509- def get_num_common_prefix_blocks (self , request_id : str ,
510- num_running_requests : int ) -> int :
516+ def get_num_common_prefix_blocks (
517+ self ,
518+ running_request_id : str ,
519+ num_running_requests : int ,
520+ transfering_request_ids : list [str ],
521+ ) -> int :
511522 """
512523 cascade attention is not supported by chunked local attention.
513524 """
@@ -541,8 +552,12 @@ def remove_skipped_blocks(self, request_id: str,
541552 # remove blocks.
542553 pass
543554
544- def get_num_common_prefix_blocks (self , request_id : str ,
545- num_running_requests : int ) -> int :
555+ def get_num_common_prefix_blocks (
556+ self ,
557+ running_request_id : str ,
558+ num_running_requests : int ,
559+ transfering_request_ids : list [str ],
560+ ) -> int :
546561 return 0
547562
548563 def allocate_new_blocks (self , request_id : str ,
@@ -568,8 +583,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
568583 # requests, so this method is not relevant.
569584 raise ValueError ("Should not be called as prefix caching is disabled." )
570585
571- def get_num_common_prefix_blocks (self , request_id : str ,
572- num_running_requests : int ) -> int :
586+ def get_num_common_prefix_blocks (
587+ self , running_request_id : str , num_running_requests : int ,
588+ transfering_request_ids : list [str ]) -> int :
573589 # Cross-attention blocks contain request-specific encoder states
574590 # and are not shared between different requests
575591 return 0
0 commit comments