@@ -285,6 +285,29 @@ def free(self, request: Request) -> None:
285285 if block .ref_cnt == 0 :
286286 self .free_block_queue .append (block )
287287
def uncache_blocks(self, request: Request) -> int:
    """Evict a request's trailing blocks that are no longer full.

    A block can be filled and cached on the strength of speculative
    tokens. When those tokens are not accepted, the request's
    ``num_computed_tokens`` no longer covers the block, so it has to be
    removed from the cache.

    Args:
        request: The request whose blocks are examined.

    Returns:
        The number of blocks that were evicted from the cache.
    """
    request_blocks = self.req_to_blocks[request.request_id]
    # Blocks before this index are still completely covered by the
    # computed tokens; everything from here on is a candidate.
    first_candidate = request.num_computed_tokens // self.block_size

    evicted = 0
    for candidate in request_blocks[first_candidate:]:
        if self._maybe_evict_cached_block(candidate):
            evicted += 1
            continue
        # The first block that is not cached ends the scan: the blocks
        # after it are not cached either.
        break
    return evicted
310+
288311 def reset_prefix_cache (self ) -> bool :
289312 """Reset prefix cache. This function may be used in RLHF
290313 flows to invalid prefix caching after the weights are updated,
@@ -386,21 +409,24 @@ def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
386409
387410 # If the block is cached, evict it.
388411 if self .enable_caching :
389- self ._evict_cached_block (curr_block )
412+ self ._maybe_evict_cached_block (curr_block )
390413
391414 curr_block .incr_ref ()
392415 ret .append (curr_block )
393416 idx += 1
394417
395418 return ret
396419
397- def _evict_cached_block (self , block : KVCacheBlock ) -> None :
420+ def _maybe_evict_cached_block (self , block : KVCacheBlock ) -> bool :
398421 """
399422 If a block is cached in `cached_block_hash_to_block`, we reset its hash
400423 metadata and evict it from the cache.
401424
402425 Args:
403426 block: The block to evict.
427+
428+ Returns:
429+ True if the block is evicted, False otherwise.
404430 """
405431 block_hash = block .block_hash
406432 if block_hash and block_hash in self .cached_block_hash_to_block :
@@ -410,6 +436,9 @@ def _evict_cached_block(self, block: KVCacheBlock) -> None:
410436 if len (self .cached_block_hash_to_block [block_hash ]) == 0 :
411437 del self .cached_block_hash_to_block [block_hash ]
412438
439+ return True
440+ return False
441+
413442 def _get_cached_block (self ,
414443 block_hash : BlockHashType ) -> Optional [KVCacheBlock ]:
415444 """Get a cached block by the block hash, or None if cache miss.
0 commit comments