Skip to content

[Core] Fix abrupt request abort #18485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/v1/core/kv_cache_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
Get the blocks for the request.
"""
return [
manager.req_to_blocks[request_id]
manager.req_to_blocks.get(request_id) or []
for manager in self.single_type_managers
]

Expand Down
14 changes: 8 additions & 6 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def __init__(
# KV Connector pushes/pulls of remote KVs for P/D and offloading.
self.connector = None
if self.vllm_config.kv_transfer_config is not None:
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
"with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)

Expand Down Expand Up @@ -985,9 +988,8 @@ def _connector_finished(
"""
if self.connector is None:
return False, None
assert len(self.kv_cache_config.kv_cache_groups
) == 1, "KV connector only supports one KV cache group now"
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]

(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)

def _update_waiting_for_remote_kv(self, request: Request) -> bool:
Expand All @@ -1002,12 +1004,12 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV.
"""
assert self.connector is not None
if request.request_id not in self.finished_recving_kv_req_ids:
return False
assert len(self.kv_cache_config.kv_cache_groups
) == 1, "KV connector only supports one KV cache group now"

# Now that the blocks are ready, actually cache them.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
# Handle the case where the number of request tokens is less than one block.
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
Expand Down