Skip to content

Commit 0d66b57

Browse files
fix(v1/kv_cache): resolve async KV transfer bug in cascade attention
* Replace ref_cnt-based common prefix detection with running request tracking
* Update get_num_common_prefix_blocks() to accept running_request_ids set
* Fix FullAttentionManager to count actual references from running requests
* Prevent incorrect cascade attention when async KV offloading delays cleanup

This resolves a bug where completed requests with pending async transfers still contributed to ref_cnt, causing incorrect cascade attention decisions.

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
1 parent 2b41cbb commit 0d66b57

File tree

4 files changed

+83
-71
lines changed

4 files changed

+83
-71
lines changed

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,27 +136,29 @@ def free(self, request_id: str) -> None:
136136
for manager in self.single_type_managers:
137137
manager.free(request_id)
138138

139-
def get_num_common_prefix_blocks(self, request_id: str,
140-
num_running_requests: int) -> list[int]:
139+
def get_num_common_prefix_blocks(
140+
self, running_request_id: str, num_running_requests: int,
141+
transfering_request_ids: list[str]) -> list[int]:
141142
"""
142143
Get the number of common prefix blocks for all requests in the RUNNING
143-
state for each kv cache group.
144+
and TRANSFERING state for each kv cache group.
144145
145146
Args:
146-
request_id: The request ID.
147+
running_request_id: The request ID of the running request.
147148
num_running_requests: The total number of requests in the RUNNING
148149
state.
150+
transfering_request_ids: List of request IDs in transfer state.
149151
150152
Returns:
151153
list[int]: The number of common prefix blocks for all requests in
152154
the RUNNING state for each kv cache group.
153155
"""
154-
num_blocks_per_group = [
155-
manager.get_num_common_prefix_blocks(request_id,
156-
num_running_requests)
156+
return [
157+
manager.get_num_common_prefix_blocks(running_request_id,
158+
num_running_requests,
159+
transfering_request_ids)
157160
for manager in self.single_type_managers
158161
]
159-
return num_blocks_per_group
160162

161163
def remove_skipped_blocks(self, request_id: str,
162164
num_computed_tokens: int) -> None:
@@ -202,8 +204,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
202204
enable_kv_cache_events)
203205
self.num_single_type_manager = len(self.single_type_managers)
204206

205-
def get_num_common_prefix_blocks(self, request_id: str,
206-
num_running_requests: int) -> list[int]:
207+
def get_num_common_prefix_blocks(
208+
self, running_request_id: str, num_running_requests: int,
209+
transfering_request_ids: list[str]) -> list[int]:
207210
return [0] * self.num_single_type_manager
208211

209212
def find_longest_cache_hit(

vllm/v1/core/kv_cache_manager.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from vllm.v1.core.kv_cache_utils import KVCacheBlock
1111
from vllm.v1.kv_cache_interface import KVCacheConfig
1212
from vllm.v1.metrics.stats import PrefixCacheStats
13-
from vllm.v1.request import Request, RequestStatus
13+
from vllm.v1.request import Request
1414

1515
logger = init_logger(__name__)
1616

@@ -321,46 +321,30 @@ def reset_prefix_cache(self) -> bool:
321321

322322
def get_num_common_prefix_blocks(
323323
self,
324-
request: Request,
324+
running_request_id: str,
325325
num_running_requests: int,
326+
transfering_request_ids: list[str],
326327
) -> list[int]:
327328
"""Calculate the number of common prefix blocks shared by all requests
328-
in the RUNNING state for each kv cache group.
329-
330-
The function determines this by selecting any request and iterating
331-
through its blocks. A block is considered a common prefix block if its
332-
`ref_cnt` equals the total number of requests in the RUNNING state.
333-
334-
NOTE(woosuk): The number of requests in the RUNNING state is **greater
335-
than or equal to** the number of requests scheduled in the current step.
336-
This is because the RUNNING state only indicates that:
337-
1. The request has not yet finished, and
338-
2. The request holds its blocks unfreed.
339-
340-
While all scheduled requests must be in the RUNNING state, the inverse
341-
is not necessarily true. There may be RUNNING requests that are not
342-
scheduled in the current step.
329+
in the RUNNING state for each kv cache group. A block is considered a
330+
common prefix block if it is referenced by ALL currently running
331+
requests.
343332
344-
This can result in an edge case where the number of common prefix blocks
345-
is 0, even though all scheduled requests share a common prefix. This
346-
occurs because there may be unscheduled RUNNING requests that do not
347-
share the common prefix. Currently, this case cannot be easily detected,
348-
so the function returns 0 in such cases.
333+
This approach correctly handles async KV offloading scenarios where
334+
completed requests may still hold block references while no longer
335+
being in the RUNNING state.
349336
350337
Args:
351-
request: Any request in the RUNNING state, used to identify the
352-
common prefix blocks.
338+
running_request_id: The request ID of the running request.
353339
num_running_requests: The total number of requests in the RUNNING
354-
state. This can be different from the number of scheduled
355-
requests in the current step.
340+
state.
341+
transfering_request_ids: List of request IDs in transfer state.
356342
357343
Returns:
358-
list[int]: The number of common prefix blocks for each kv cache
359-
group.
344+
list[int]: Number of common prefix blocks for each kv cache group.
360345
"""
361-
assert request.status == RequestStatus.RUNNING
362346
return self.coordinator.get_num_common_prefix_blocks(
363-
request.request_id, num_running_requests)
347+
running_request_id, num_running_requests, transfering_request_ids)
364348

365349
def take_events(self) -> list[KVCacheEvent]:
366350
"""Take the KV cache events from the block pool.
@@ -386,4 +370,4 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
386370
def create_empty_block_list(self) -> KVCacheBlocks:
387371
"""Creates a new KVCacheBlocks instance with no blocks."""
388372
return KVCacheBlocks(tuple([]
389-
for _ in range(self.num_kv_cache_groups)))
373+
for _ in range(self.num_kv_cache_groups)))

vllm/v1/core/sched/scheduler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,18 @@ def schedule(self) -> SchedulerOutput:
542542
self.kv_cache_config.kv_cache_groups)
543543
if self.running:
544544
any_request = self.running[0]
545+
running_request_ids = {req.request_id for req in self.running}
546+
547+
# Include requests in KV transfer state for common prefix calc
548+
transferring_request_ids = [
549+
req_id for req_id, request in self.requests.items()
550+
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS and
551+
any(self.kv_cache_manager.get_blocks(req_id).get_block_ids())
552+
]
545553
num_common_prefix_blocks = (
546554
self.kv_cache_manager.get_num_common_prefix_blocks(
547-
any_request, len(self.running)))
555+
any_request.request_id, len(running_request_ids),
556+
transferring_request_ids))
548557

549558
# Construct the scheduler output.
550559
new_reqs_data = [

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -168,22 +168,10 @@ def free(self, request_id: str) -> None:
168168
self.num_cached_block.pop(request_id, None)
169169

170170
@abstractmethod
171-
def get_num_common_prefix_blocks(self, request_id: str,
172-
num_running_requests: int) -> int:
173-
"""
174-
Get the number of common prefix blocks for all requests in the RUNNING
175-
state.
176-
177-
Args:
178-
request_id: The request ID.
179-
num_running_requests: The total number of requests in the RUNNING
180-
state.
181-
182-
Returns:
183-
The number of common prefix blocks for all requests in the RUNNING
184-
state.
185-
"""
186-
171+
def get_num_common_prefix_blocks(
172+
self, running_request_id: str, num_running_requests: int,
173+
transfering_request_ids: list[str]) -> int:
174+
"""Get the number of common prefix blocks for all running requests."""
187175
raise NotImplementedError
188176

189177
@classmethod
@@ -281,15 +269,30 @@ def remove_skipped_blocks(self, request_id: str,
281269
# No need to remove blocks for full attention.
282270
pass
283271

284-
def get_num_common_prefix_blocks(self, request_id: str,
285-
num_running_requests: int) -> int:
286-
blocks = self.req_to_blocks[request_id]
272+
def get_num_common_prefix_blocks(
273+
self, running_request_id: str, num_running_requests: int,
274+
transfering_request_ids: list[str]) -> int:
275+
"""Get common prefix blocks using ref_cnt with transferring requests."""
276+
if running_request_id not in self.req_to_blocks:
277+
return 0
278+
279+
running_blocks = self.req_to_blocks[running_request_id]
280+
transfering_blocks = [
281+
self.req_to_blocks[req_id] for req_id in transfering_request_ids
282+
if req_id in self.req_to_blocks
283+
]
284+
287285
num_common_blocks = 0
288-
for block in blocks:
289-
if block.ref_cnt == num_running_requests:
286+
for i, block in enumerate(running_blocks):
287+
num_transfering_blocks = sum(
288+
1 for blocks in transfering_blocks if i < len(blocks)
289+
and blocks[i].block_id == running_blocks[i].block_id)
290+
291+
if block.ref_cnt == num_running_requests + num_transfering_blocks:
290292
num_common_blocks += 1
291293
else:
292294
break
295+
293296
return num_common_blocks
294297

295298

@@ -380,8 +383,12 @@ def remove_skipped_blocks(self, request_id: str,
380383
blocks[i] = self._null_block
381384
self.block_pool.free_blocks(removed_blocks)
382385

383-
def get_num_common_prefix_blocks(self, request_id: str,
384-
num_running_requests: int) -> int:
386+
def get_num_common_prefix_blocks(
387+
self,
388+
running_request_id: str,
389+
num_running_requests: int,
390+
transfering_request_ids: list[str],
391+
) -> int:
385392
"""
386393
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
387394
So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -506,8 +513,12 @@ def remove_skipped_blocks(self, request_id: str,
506513
blocks[i] = self._null_block
507514
self.block_pool.free_blocks(removed_blocks)
508515

509-
def get_num_common_prefix_blocks(self, request_id: str,
510-
num_running_requests: int) -> int:
516+
def get_num_common_prefix_blocks(
517+
self,
518+
running_request_id: str,
519+
num_running_requests: int,
520+
transfering_request_ids: list[str],
521+
) -> int:
511522
"""
512523
cascade attention is not supported by chunked local attention.
513524
"""
@@ -541,8 +552,12 @@ def remove_skipped_blocks(self, request_id: str,
541552
# remove blocks.
542553
pass
543554

544-
def get_num_common_prefix_blocks(self, request_id: str,
545-
num_running_requests: int) -> int:
555+
def get_num_common_prefix_blocks(
556+
self,
557+
running_request_id: str,
558+
num_running_requests: int,
559+
transfering_request_ids: list[str],
560+
) -> int:
546561
return 0
547562

548563
def allocate_new_blocks(self, request_id: str,
@@ -568,8 +583,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
568583
# requests, so this method is not relevant.
569584
raise ValueError("Should not be called as prefix caching is disabled.")
570585

571-
def get_num_common_prefix_blocks(self, request_id: str,
572-
num_running_requests: int) -> int:
586+
def get_num_common_prefix_blocks(
587+
self, running_request_id: str, num_running_requests: int,
588+
transfering_request_ids: list[str]) -> int:
573589
# Cross-attention blocks contain request-specific encoder states
574590
# and are not shared between different requests
575591
return 0

0 commit comments

Comments
 (0)