Skip to content

Commit eeaddfc

Browse files
njhillgmarinho2
authored andcommitted
[P/D] Avoid stranding blocks in P when aborted in D's waiting queue (vllm-project#19223)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 60466f1 commit eeaddfc

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,21 @@ def request_finished(
298298
logger.debug(
299299
"NIXLConnector request_finished, request_status=%s, "
300300
"kv_transfer_params=%s", request.status, params)
301+
if not params:
302+
return False, None
303+
304+
if params.get("do_remote_prefill"):
305+
# If do_remote_prefill is still True when the request is finished,
306+
# update_state_after_alloc must not have been called (the request
307+
# must have been aborted before it was scheduled).
308+
# To avoid stranding the prefill blocks in the prefill instance,
309+
# we must add empty block_ids to _reqs_need_recv so that our
310+
# worker side will notify and free blocks in the prefill instance.
311+
self._reqs_need_recv[request.request_id] = (request, [])
312+
params["do_remote_prefill"] = False
313+
return False, None
301314

302-
if (params is None or not params.get("do_remote_decode")
315+
if (not params.get("do_remote_decode")
303316
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
304317
return False, None
305318

0 commit comments

Comments
 (0)