
Commit cd98905

fix(v1/kv_cache): resolve async KV transfer bug in cascade attention (#23485)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
1 parent 067da2d

4 files changed: +41 −72 lines

vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/sched/scheduler.py
vllm/v1/core/single_type_kv_cache_manager.py

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 9 additions & 16 deletions

@@ -148,27 +148,22 @@ def free(self, request_id: str) -> None:
         for manager in self.single_type_managers:
             manager.free(request_id)

-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> list[int]:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
         """
-        Get the number of common prefix blocks for all requests in the RUNNING
-        state for each kv cache group.
+        Get the number of common prefix blocks for all requests with allocated
+        KV cache for each kv cache group.

         Args:
-            request_id: The request ID.
-            num_running_requests: The total number of requests in the RUNNING
-                state.
+            running_request_id: The request ID of any running request, used to
+                identify the common prefix blocks.

         Returns:
-            list[int]: The number of common prefix blocks for all requests in
-                the RUNNING state for each kv cache group.
+            list[int]: The number of common prefix blocks for each kv cache group.
         """
-        num_blocks_per_group = [
-            manager.get_num_common_prefix_blocks(request_id, num_running_requests)
+        return [
+            manager.get_num_common_prefix_blocks(running_request_id)
             for manager in self.single_type_managers
         ]
-        return num_blocks_per_group

     def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         """
@@ -226,9 +221,7 @@ def __init__(
         )
         self.num_single_type_manager = len(self.single_type_managers)

-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> list[int]:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
         return [0] * self.num_single_type_manager

     def find_longest_cache_hit(
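For readers skimming the diff, the coordinator change boils down to fanning a single running request ID out to one manager per KV cache group. The standalone sketch below uses toy classes of my own naming (ToyManager, ToyCoordinator), not vLLM's actual KVCacheCoordinator or single-type managers, to show the shape of that aggregation.

class ToyManager:
    def __init__(self, req_to_num_common: dict[str, int]):
        # Maps a request ID to the number of common prefix blocks this
        # group would report for it.
        self.req_to_num_common = req_to_num_common

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        return self.req_to_num_common.get(running_request_id, 0)


class ToyCoordinator:
    def __init__(self, single_type_managers: list[ToyManager]):
        self.single_type_managers = single_type_managers

    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
        # One entry per KV cache group, mirroring the list comprehension
        # introduced in the diff above.
        return [
            manager.get_num_common_prefix_blocks(running_request_id)
            for manager in self.single_type_managers
        ]


if __name__ == "__main__":
    coord = ToyCoordinator([ToyManager({"req-0": 3}), ToyManager({"req-0": 0})])
    print(coord.get_num_common_prefix_blocks("req-0"))  # [3, 0]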

vllm/v1/core/kv_cache_manager.py

Lines changed: 18 additions & 28 deletions

@@ -10,7 +10,7 @@
 from vllm.v1.core.kv_cache_utils import KVCacheBlock
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request

 logger = init_logger(__name__)

@@ -344,49 +344,39 @@ def reset_prefix_cache(self) -> bool:
         self.prefix_cache_stats.reset = True
         return True

-    def get_num_common_prefix_blocks(
-        self,
-        request: Request,
-        num_running_requests: int,
-    ) -> list[int]:
-        """Calculate the number of common prefix blocks shared by all requests
-        in the RUNNING state for each kv cache group.
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
+        """Calculate the number of common prefix blocks for each kv cache group.

-        The function determines this by selecting any request and iterating
-        through its blocks. A block is considered a common prefix block if its
-        `ref_cnt` equals the total number of requests in the RUNNING state.
+        The function selects a running request and iterates through its blocks.
+        A block is considered a common prefix block if ALL requests with
+        allocated KV cache share it (i.e., ref_cnt equals the number of entries
+        in req_to_blocks).

-        NOTE(woosuk): The number of requests in the RUNNING state is **greater
+        NOTE(woosuk): The number of requests with allocated KV cache is **greater
         than or equal to** the number of requests scheduled in the current step.
-        This is because the RUNNING state only indicates that:
+        This is because having allocated KV cache only indicates that:
         1. The request has not yet finished, and
         2. The request holds its blocks unfreed.

-        While all scheduled requests must be in the RUNNING state, the inverse
-        is not necessarily true. There may be RUNNING requests that are not
-        scheduled in the current step.
+        While all scheduled requests must have allocated KV cache, the inverse
+        is not necessarily true. There may be requests with allocated KV cache
+        that are not scheduled in the current step.

         This can result in an edge case where the number of common prefix blocks
         is 0, even though all scheduled requests share a common prefix. This
-        occurs because there may be unscheduled RUNNING requests that do not
-        share the common prefix. Currently, this case cannot be easily detected,
-        so the function returns 0 in such cases.
+        occurs because there may be unscheduled requests that do not share the
+        common prefix. Currently, this case cannot be easily detected, so the
+        function returns 0 in such cases.

         Args:
-            request: Any request in the RUNNING state, used to identify the
-                common prefix blocks.
-            num_running_requests: The total number of requests in the RUNNING
-                state. This can be different from the number of scheduled
-                requests in the current step.
+            running_request_id: The request ID of any running request, used to
+                identify the common prefix blocks.

         Returns:
             list[int]: The number of common prefix blocks for each kv cache
                 group.
         """
-        assert request.status == RequestStatus.RUNNING
-        return self.coordinator.get_num_common_prefix_blocks(
-            request.request_id, num_running_requests
-        )
+        return self.coordinator.get_num_common_prefix_blocks(running_request_id)

     def take_events(self) -> list[KVCacheEvent]:
         """Take the KV cache events from the block pool.

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion

@@ -597,7 +597,7 @@ def schedule(self) -> SchedulerOutput:
             any_request = self.running[0]
             num_common_prefix_blocks = (
                 self.kv_cache_manager.get_num_common_prefix_blocks(
-                    any_request, len(self.running)
+                    any_request.request_id
                 )
             )


vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 13 additions & 27 deletions

@@ -182,21 +182,17 @@ def free(self, request_id: str) -> None:
         self.num_cached_block.pop(request_id, None)

     @abstractmethod
-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
-        Get the number of common prefix blocks for all requests in the RUNNING
-        state.
+        Get the number of common prefix blocks for all requests with allocated
+        KV cache.

         Args:
-            request_id: The request ID.
-            num_running_requests: The total number of requests in the RUNNING
-                state.
+            running_request_id: The request ID.

         Returns:
-            The number of common prefix blocks for all requests in the RUNNING
-            state.
+            The number of common prefix blocks for all requests with allocated
+            KV cache.
         """

         raise NotImplementedError
@@ -302,13 +298,11 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         # No need to remove blocks for full attention.
         pass

-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
-        blocks = self.req_to_blocks[request_id]
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
+        blocks = self.req_to_blocks[running_request_id]
         num_common_blocks = 0
         for block in blocks:
-            if block.ref_cnt == num_running_requests:
+            if block.ref_cnt == len(self.req_to_blocks):
                 num_common_blocks += 1
             else:
                 break
@@ -408,9 +402,7 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
             blocks[i] = self._null_block
         self.block_pool.free_blocks(removed_blocks)

-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
         So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -544,9 +536,7 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
             blocks[i] = self._null_block
         self.block_pool.free_blocks(removed_blocks)

-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         cascade attention is not supported by chunked local attention.
         """
@@ -596,9 +586,7 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
         # (for which find_longest_cache_hit returns block_pool.null_block)
         pass

-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
         cascade attention is not supported by mamba
         """
@@ -648,9 +636,7 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
         # requests, so this method is not relevant.
         raise ValueError("Should not be called as prefix caching is disabled.")

-    def get_num_common_prefix_blocks(
-        self, request_id: str, num_running_requests: int
-    ) -> int:
+    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         # Cross-attention blocks contain request-specific encoder states
         # and are not shared between different requests
         return 0
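The substantive change in FullAttentionManager is the comparison target: ref_cnt is now checked against len(self.req_to_blocks) rather than the scheduler-supplied num_running_requests. The sketch below is an illustrative reading of why that matters, assuming (per the commit title) that under async KV transfer a request can hold allocated blocks without being counted as running; the toy names and the specific scenario are mine, not taken from the commit.

from dataclasses import dataclass


@dataclass
class ToyBlock:
    block_id: int
    ref_cnt: int = 0


block0 = ToyBlock(0)

# A is running and references block0. X also references block0, but it is
# still waiting on its KV transfer and is therefore not in the running set.
# B is running but does NOT reference block0.
req_to_blocks = {
    "A": [block0],
    "X": [block0],
    "B": [ToyBlock(1)],
}
block0.ref_cnt = 2          # referenced by A and X
num_running_requests = 2    # A and B

# Old rule: block0 looks like a common prefix block (2 == 2) even though the
# running request B never allocated it, so cascade attention could be handed
# a bogus common prefix.
print(block0.ref_cnt == num_running_requests)  # True (misleading)

# New rule: block0 is common only if every holder of KV cache references it,
# which is false here (2 != 3).
print(block0.ref_cnt == len(req_to_blocks))    # False (conservative, correct)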
