-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[PD][Nixl] Remote consumer READ timeout for clearing request blocks #20139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,6 +1,7 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import contextlib | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import copy | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import math | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import queue | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import threading | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -79,15 +80,16 @@ class ReqMeta: | |||||||||||||||||||||||||||||||||||||||||||||||||||
class NixlConnectorMetadata(KVConnectorMetadata): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.requests: dict[ReqId, ReqMeta] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.reqs_to_recv: dict[ReqId, ReqMeta] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.reqs_to_send: set[str] = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def add_new_req( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
request_id: ReqId, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
local_block_ids: list[int], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
kv_transfer_params: dict[str, Any], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.requests[request_id] = ReqMeta( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self.reqs_to_recv[request_id] = ReqMeta( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
local_block_ids=local_block_ids, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
remote_block_ids=kv_transfer_params["remote_block_ids"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
remote_engine_id=kv_transfer_params["remote_engine_id"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -194,10 +196,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): | |||||||||||||||||||||||||||||||||||||||||||||||||||
vllm_config.parallel_config.tensor_parallel_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.info("Initializing NIXL Scheduler %s", engine_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# Requests that need to start recv. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Requests that need to start recv/send. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# New requests are added by update_state_after_alloc in | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# the scheduler. Used to make metadata passed to Worker. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._reqs_need_send: set[str] = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_num_new_matched_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, request: "Request", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -265,6 +268,9 @@ def update_state_after_alloc(self, request: "Request", | |||||||||||||||||||||||||||||||||||||||||||||||||||
assert num_external_tokens == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Only trigger 1 KV transfer per request. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
params["do_remote_prefill"] = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||
elif params is not None and params.get("do_remote_decode"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Prefill request on remote. It will be read from D upon completion | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._reqs_need_send.add(request.request_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def build_connector_meta( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -281,8 +287,10 @@ def build_connector_meta( | |||||||||||||||||||||||||||||||||||||||||||||||||||
kv_transfer_params=req.kv_transfer_params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
meta.reqs_to_send = copy.copy(self._reqs_need_send) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Clear the list once workers start the transfers | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._reqs_need_recv.clear() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._reqs_need_send.clear() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+290
to
+293
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can avoid copying
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
return meta | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -394,6 +402,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): | |||||||||||||||||||||||||||||||||||||||||||||||||||
# In progress transfers. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# [req_id -> list[handle]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Keep track of the time for requests that are waiting to be sent. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._reqs_to_send: dict[ReqId, float] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# Complete transfer tracker. Used by the rank 0 to track finished | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# transactions on ranks 1 to N-1. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -799,6 +809,24 @@ def get_finished(self) -> tuple[set[str], set[str]]: | |||||||||||||||||||||||||||||||||||||||||||||||||||
"and %s requests done recving", self.tp_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
len(done_sending), len(done_recving)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# Handle timeout to avoid stranding blocks on remote. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
now = time.monotonic() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
timed_out_requests: list[str] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
for req_id, finish_time in self._reqs_to_send.items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if finish_time < 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Request just finished, start timeout. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._reqs_to_send[req_id] = now | ||||||||||||||||||||||||||||||||||||||||||||||||||||
elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Timeout exceed, clear the request blocks. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
timed_out_requests.append(req_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
for req_id in timed_out_requests: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Skip communication with other ranks, but | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.tp_rank == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._done_sending_count[req_id] += self.world_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||
done_sending.add(req_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
del self._reqs_to_send[req_id] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
NickLucche marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+815
to
+828
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The timeout mechanism implemented here relies on import time
# Use time.monotonic() instead of time.perf_counter()
now = time.monotonic()
timed_out_requests: list[str] = []
for req_id, finish_time in self._reqs_to_send.items():
if finish_time < 0:
# Request just finished, start timeout.
self._reqs_to_send[req_id] = now
elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT:
# Timeout exceed, clear the request blocks.
timed_out_requests.append(req_id)
Comment on lines
+812
to
+828
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's better to keep things simple and omit the TP optimization. I think we'll likely make this logic generic and move it outside of the connector impl anyhow (aggregating the finished events in TP case). Dicts are ordered so we only need to peek the oldest entry.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.world_size == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return done_sending, done_recving | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -830,7 +858,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: | |||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
all_done_sending: set[str] = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
for req_id in list(self._done_sending_count.keys()): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if self._done_sending_count[req_id] == self.world_size: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if self._done_sending_count[req_id] >= self.world_size: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
del self._done_sending_count[req_id] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
all_done_sending.add(req_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
NickLucche marked this conversation as resolved.
Show resolved
Hide resolved
NickLucche marked this conversation as resolved.
Show resolved
Hide resolved
NickLucche marked this conversation as resolved.
Show resolved
Hide resolved
NickLucche marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -860,6 +888,7 @@ def _get_new_notifs(self) -> set[str]: | |||||||||||||||||||||||||||||||||||||||||||||||||||
tp_ratio): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
notified_req_ids.add(req_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
del self.consumer_notification_counts_by_req[req_id] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
del self._reqs_to_send[req_id] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return notified_req_ids | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def _pop_done_transfers( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -894,7 +923,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): | |||||||||||||||||||||||||||||||||||||||||||||||||||
Start loading by triggering non-blocking nixl_xfer. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
We check for these trnxs to complete in each step(). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
for req_id, meta in metadata.requests.items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
for req_id, meta in metadata.reqs_to_recv.items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
remote_engine_id = meta.remote_engine_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"start_load_kv for request %s from remote engine %s. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -939,6 +968,11 @@ def request_ready(_f: Future[Any], | |||||||||||||||||||||||||||||||||||||||||||||||||||
while not self._ready_requests.empty(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._read_blocks_for_req(*self._ready_requests.get_nowait()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# Track the request that are waiting to be read and abort on timeout. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# Set to -1 so that timeout does not depend on model latency. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
for req_id in metadata.reqs_to_send: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._reqs_to_send[req_id] = -1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+971
to
+974
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this would be needed per my other comments |
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"Remote agent %s available, calling _read_blocks for req %s", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1043,8 +1077,7 @@ def _read_blocks(self, local_block_ids: list[int], | |||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# Use handle to check completion in future step(). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
# TODO (NickLucche) surface xfer elapsed time | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._recving_transfers[request_id].append( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
(handle, time.perf_counter())) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self._recving_transfers[request_id].append((handle, time.monotonic())) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def _get_block_descs_ids(self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
engine_id: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should go in
request_finished()
, and only do it if we returnTrue
from that.We can also set the absolute deadline at this point (set can be of tuple(req_id, deadline)), and include it in the transfer params that are returned (so the D worker can check it in it's
get_num_matched_tokens
method).And clearer to set a deadline than the finished time... but should include some buffer to allow for transfer time and slightly misaligned clocks .. e.g. 60sec deadline for D side, 90 sec expiry on P side.