Skip to content

[P/D] Avoid stranding blocks in P when aborted in D's waiting queue #19223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 25, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,21 @@ def request_finished(
logger.debug(
"NIXLConnector request_finished, request_status=%s, "
"kv_transfer_params=%s", request.status, params)
if not params:
return False, None

if params.get("do_remote_prefill"):
# If do_remote_prefill is still True when the request is finished,
# update_state_after_alloc must not have been called (the request
# must have been aborted before it was scheduled).
# To avoid stranding the prefill blocks in the prefill instance,
# we must add empty block_ids to _reqs_need_recv so that our
# worker side will notify and free blocks in the prefill instance.
self._reqs_need_recv[request.request_id] = (request, [])
params["do_remote_prefill"] = False
return False, None

if (params is None or not params.get("do_remote_decode")
if (not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None

Expand Down