
Commit 912667d

fix(v1/kv_cache): resolve async KV transfer bug in cascade attention
* Replace ref_cnt-based common prefix detection with running request tracking
* Update get_num_common_prefix_blocks() to accept running_request_ids set
* Fix FullAttentionManager to count actual references from running requests
* Prevent incorrect cascade attention when async KV offloading delays cleanup

This resolves a bug where completed requests with pending async transfers still
contributed to ref_cnt, causing incorrect cascade attention decisions.

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
1 parent 185d8ed commit 912667d
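
For context, a minimal sketch of the failure mode the commit message describes. Everything here (the `Block` dataclass, request IDs `A`, `B`, `T`, the `req_to_blocks` dict) is a toy stand-in for illustration, not the actual vLLM `KVCacheBlock` or scheduler state: two RUNNING requests with unrelated prefixes, plus a finished request whose async KV transfer is still pending and therefore still holds block references. Under the old `ref_cnt == num_running_requests` rule, the stale references make request A's private prefix look common to all running requests; the new rule discounts them first.

```python
# Toy illustration of the bug; simplified stand-ins, not the real vLLM classes.
from dataclasses import dataclass


@dataclass
class Block:
    block_id: int
    ref_cnt: int = 0


b0, b1, c0 = Block(0), Block(1), Block(2)
req_to_blocks = {
    "A": [b0, b1],  # RUNNING
    "B": [c0],      # RUNNING, different prefix
    "T": [b0, b1],  # finished, but its async KV offload has not completed yet
}
for blocks in req_to_blocks.values():
    for blk in blocks:
        blk.ref_cnt += 1

num_running = 2  # only A and B are RUNNING

# Old rule: a block is a common prefix block if ref_cnt == num_running_requests
# (the real code also stops at the first non-common block; same result here).
old = sum(1 for blk in req_to_blocks["A"] if blk.ref_cnt == num_running)
print(old)  # 2 -> cascade attention would wrongly treat A's prefix as common

# New rule: discount references held by transferring requests before comparing.
transferring = ["T"]
new = 0
for i, blk in enumerate(req_to_blocks["A"]):
    in_transfer = sum(
        1
        for rid in transferring
        if i < len(req_to_blocks[rid])
        and req_to_blocks[rid][i].block_id == blk.block_id
    )
    if blk.ref_cnt - in_transfer == num_running:
        new += 1
    else:
        break
print(new)  # 0 -> A's prefix is not actually shared by all running requests
```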

File tree: 4 files changed (+83 −41 lines)


vllm/v1/core/kv_cache_coordinator.py

Lines changed: 16 additions & 8 deletions
@@ -149,26 +149,31 @@ def free(self, request_id: str) -> None:
         manager.free(request_id)
 
     def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
+        self,
+        running_request_id: str,
+        num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> list[int]:
         """
         Get the number of common prefix blocks for all requests in the RUNNING
         state for each kv cache group.
 
         Args:
-            request_id: The request ID.
-            num_running_requests: The total number of requests in the RUNNING
-                state.
+            running_request_id: The request ID of the running request.
+            num_running_requests: The number of requests in the RUNNING state.
+            transfering_request_ids: List of request IDs in
+                WAITING_FOR_REMOTE_KVS state.
 
         Returns:
             list[int]: The number of common prefix blocks for all requests in
                 the RUNNING state for each kv cache group.
         """
-        num_blocks_per_group = [
-            manager.get_num_common_prefix_blocks(request_id, num_running_requests)
+        return [
+            manager.get_num_common_prefix_blocks(
+                running_request_id, num_running_requests, transfering_request_ids
+            )
             for manager in self.single_type_managers
         ]
-        return num_blocks_per_group
 
     def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         """

@@ -227,7 +232,10 @@ def __init__(
         self.num_single_type_manager = len(self.single_type_managers)
 
     def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
+        self,
+        running_request_id: str,
+        num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> list[int]:
         return [0] * self.num_single_type_manager
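
The coordinator's role here is just fan-out: it returns one result per KV cache group, and groups that cannot use cascade attention (sliding window, Mamba, cross-attention; see single_type_kv_cache_manager.py below) report 0. A rough sketch of that shape, with toy classes standing in for the real single-type managers:

```python
# Toy per-group fan-out; these classes are illustrative stand-ins only.
class FullAttnGroupToy:
    def get_num_common_prefix_blocks(
        self, running_request_id, num_running_requests, transfering_request_ids
    ):
        return 3  # pretend three prefix blocks are shared by all running requests


class SlidingWindowGroupToy:
    def get_num_common_prefix_blocks(
        self, running_request_id, num_running_requests, transfering_request_ids
    ):
        return 0  # cascade attention is not supported for this group


class CoordinatorToy:
    def __init__(self):
        self.single_type_managers = [FullAttnGroupToy(), SlidingWindowGroupToy()]

    def get_num_common_prefix_blocks(
        self, running_request_id, num_running_requests, transfering_request_ids
    ):
        # Mirrors the structure above: one entry per KV cache group.
        return [
            m.get_num_common_prefix_blocks(
                running_request_id, num_running_requests, transfering_request_ids
            )
            for m in self.single_type_managers
        ]


print(CoordinatorToy().get_num_common_prefix_blocks("req-0", 2, []))  # [3, 0]
```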

vllm/v1/core/kv_cache_manager.py

Lines changed: 13 additions & 9 deletions
@@ -10,7 +10,7 @@
 from vllm.v1.core.kv_cache_utils import KVCacheBlock
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request
 
 logger = init_logger(__name__)
 

@@ -346,15 +346,18 @@ def reset_prefix_cache(self) -> bool:
 
     def get_num_common_prefix_blocks(
         self,
-        request: Request,
+        running_request_id: str,
         num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> list[int]:
         """Calculate the number of common prefix blocks shared by all requests
         in the RUNNING state for each kv cache group.
 
-        The function determines this by selecting any request and iterating
-        through its blocks. A block is considered a common prefix block if its
-        `ref_cnt` equals the total number of requests in the RUNNING state.
+        The function determines this by selecting any running request and
+        iterating through its blocks. A block is considered a common prefix
+        block if it is shared by ALL currently running requests. Transferring
+        requests (those in WAITING_FOR_REMOTE_KVS state) are excluded from
+        this check, as they may not have fully loaded their KV cache yet.
 
         NOTE(woosuk): The number of requests in the RUNNING state is **greater
         than or equal to** the number of requests scheduled in the current step.

@@ -373,19 +376,20 @@ def get_num_common_prefix_blocks(
         so the function returns 0 in such cases.
 
         Args:
-            request: Any request in the RUNNING state, used to identify the
-                common prefix blocks.
+            running_request_id: The request ID of any running request, used to
+                identify the common prefix blocks.
             num_running_requests: The total number of requests in the RUNNING
                 state. This can be different from the number of scheduled
                 requests in the current step.
+            transfering_request_ids: List of request IDs in transfer state
+                (WAITING_FOR_REMOTE_KVS).
 
         Returns:
             list[int]: The number of common prefix blocks for each kv cache
                 group.
         """
-        assert request.status == RequestStatus.RUNNING
         return self.coordinator.get_num_common_prefix_blocks(
-            request.request_id, num_running_requests
+            running_request_id, num_running_requests, transfering_request_ids
         )
 
     def take_events(self) -> list[KVCacheEvent]:

vllm/v1/core/sched/scheduler.py

Lines changed: 10 additions & 1 deletion
@@ -595,9 +595,18 @@ def schedule(self) -> SchedulerOutput:
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
         if self.running:
             any_request = self.running[0]
+            num_running_requests = len(self.running)
+
+            transferring_request_ids = [
+                req_id
+                for req_id, request in self.requests.items()
+                if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
+            ]
             num_common_prefix_blocks = (
                 self.kv_cache_manager.get_num_common_prefix_blocks(
-                    any_request, len(self.running)
+                    any_request.request_id,
+                    num_running_requests,
+                    transferring_request_ids,
                 )
             )

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 44 additions & 23 deletions
@@ -183,22 +183,12 @@ def free(self, request_id: str) -> None:
 
     @abstractmethod
     def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
+        self,
+        running_request_id: str,
+        num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> int:
-        """
-        Get the number of common prefix blocks for all requests in the RUNNING
-        state.
-
-        Args:
-            request_id: The request ID.
-            num_running_requests: The total number of requests in the RUNNING
-                state.
-
-        Returns:
-            The number of common prefix blocks for all requests in the RUNNING
-            state.
-        """
-
+        """Get the number of common prefix blocks for all running requests."""
         raise NotImplementedError
 
     @classmethod

@@ -303,15 +293,34 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         pass
 
     def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
+        self,
+        running_request_id: str,
+        num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> int:
-        blocks = self.req_to_blocks[request_id]
+        """Get common prefix blocks shared by all running requests."""
+
+        reference_blocks = self.req_to_blocks[running_request_id]
+
+        transferring_blocks = [
+            self.req_to_blocks[req_id]
+            for req_id in transfering_request_ids
+            if req_id in self.req_to_blocks
+        ]
+
         num_common_blocks = 0
-        for block in blocks:
-            if block.ref_cnt == num_running_requests:
+        for i, ref_block in enumerate(reference_blocks):
+            transferring_has_block = sum(
+                1
+                for blocks in transferring_blocks
+                if i < len(blocks) and blocks[i].block_id == ref_block.block_id
+            )
+
+            if ref_block.ref_cnt - transferring_has_block == num_running_requests:
                 num_common_blocks += 1
             else:
                 break
+
         return num_common_blocks

@@ -409,7 +418,10 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         self.block_pool.free_blocks(removed_blocks)
 
     def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
+        self,
+        running_request_id: str,
+        num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> int:
         """
         NOTE(Chen): The prefix blocks are null blocks for sliding window layers.

@@ -545,7 +557,10 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         self.block_pool.free_blocks(removed_blocks)
 
     def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
+        self,
+        running_request_id: str,
+        num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> int:
         """
         cascade attention is not supported by chunked local attention.

@@ -597,7 +612,10 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         pass
 
     def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
+        self,
+        running_request_id: str,
+        num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> int:
         """
         cascade attention is not supported by mamba

@@ -649,7 +667,10 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
         raise ValueError("Should not be called as prefix caching is disabled.")
 
     def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
+        self,
+        running_request_id: str,
+        num_running_requests: int,
+        transfering_request_ids: list[str],
     ) -> int:
         # Cross-attention blocks contain request-specific encoder states
         # and are not shared between different requests
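
The subtraction in the FullAttentionManager hunk above also handles the converse case: a transferring request that does share the running requests' common prefix no longer makes the equality check fail, so cascade attention is not needlessly suppressed. A self-contained sketch of that counting rule; the `Block` dataclass and the dict bookkeeping are simplified stand-ins, not the real vLLM classes:

```python
# Standalone sketch of the new per-block counting rule from the diff above.
from dataclasses import dataclass


@dataclass
class Block:
    block_id: int
    ref_cnt: int = 0


def num_common_prefix_blocks(req_to_blocks, running_request_id,
                             num_running_requests, transfering_request_ids):
    reference_blocks = req_to_blocks[running_request_id]
    transferring_blocks = [
        req_to_blocks[rid] for rid in transfering_request_ids if rid in req_to_blocks
    ]
    num_common = 0
    for i, ref_block in enumerate(reference_blocks):
        # Discount references held by transferring requests on this exact block
        # before comparing against the number of running requests.
        in_transfer = sum(
            1 for blocks in transferring_blocks
            if i < len(blocks) and blocks[i].block_id == ref_block.block_id
        )
        if ref_block.ref_cnt - in_transfer == num_running_requests:
            num_common += 1
        else:
            break
    return num_common


# Two running requests (A, B) share blocks 0 and 1; a transferring request T
# still references them while its KV transfer completes, so each ref_cnt is 3.
b0, b1 = Block(0, ref_cnt=3), Block(1, ref_cnt=3)
req_to_blocks = {"A": [b0, b1], "B": [b0, b1], "T": [b0, b1]}

# The old rule (ref_cnt == num_running_requests) would see 3 != 2 and report 0;
# the new rule subtracts T's reference and recovers the shared prefix.
print(num_common_prefix_blocks(req_to_blocks, "A", 2, ["T"]))  # 2
```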
