[PD][Nixl] Remote consumer READ timeout for clearing request blocks #20139


Open · wants to merge 5 commits into base: main
tests/v1/kv_connector/unit/test_nixl_connector.py (67 additions, 0 deletions)
@@ -9,10 +9,13 @@

import pytest

from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
    NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams

from .utils import create_request, create_scheduler, create_vllm_config

@@ -371,3 +374,67 @@ def test_concurrent_load_kv(
        if cnt_finished_reqs == total_reqs:
            return
    raise TimeoutError("Took too long to complete async handshake.")


def test_abort_timeout_on_prefiller(monkeypatch):
    """
    Test lifecycle of an aborted Remote Prefill request hitting the timeout.
    -----> P
            |  {process request}
    <-\---  |  {result is NOT delivered, e.g. proxy is down}
            |
            |
            |  {eventually free blocks}
    """
    model_name = "Qwen/Qwen3-0.6B"
    kv_transfer_config = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    timeout = 6
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
    llm = LLM(
        model=model_name,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=kv_transfer_config,
    )
    remote_prefill_opts = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
    }
    # Simulate sidecar request
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=1,
        extra_args={"kv_transfer_params": remote_prefill_opts})
    scheduler = llm.llm_engine.engine_core.engine_core.scheduler
    req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks

    padding = (
        "Just making this request a little longer so that we're sure "
        "we're not hitting the small-request lower bound beneath which we don't "
        "actually trigger the whole kv transfer, but rather just recompute the "
        "blocks on D.")
    _ = llm.generate([f"What is the capital of Japan? {padding}"],
                     sampling_params)

    # Request finished but not freed
    assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks
    # Some other request; request 0 is still not freed
    _ = llm.generate([f"What is the capital of Italy? {padding}"],
                     sampling_params)
    assert '0' in req_to_blocks
    assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks

    # Wait for timeout and trigger another scheduler loop
    time.sleep(timeout)
    _ = llm.generate([f"What is the capital of France? {padding}"],
                     sampling_params)
    # Request 0 times out and its blocks are cleared!
    assert '0' not in req_to_blocks
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (40 additions, 7 deletions)
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
+import copy
 import math
 import queue
 import threading
@@ -79,15 +80,16 @@ class ReqMeta:
 class NixlConnectorMetadata(KVConnectorMetadata):

     def __init__(self):
-        self.requests: dict[ReqId, ReqMeta] = {}
+        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
+        self.reqs_to_send: set[str] = set()

     def add_new_req(
         self,
         request_id: ReqId,
         local_block_ids: list[int],
         kv_transfer_params: dict[str, Any],
     ):
-        self.requests[request_id] = ReqMeta(
+        self.reqs_to_recv[request_id] = ReqMeta(
             local_block_ids=local_block_ids,
             remote_block_ids=kv_transfer_params["remote_block_ids"],
             remote_engine_id=kv_transfer_params["remote_engine_id"],
@@ -194,10 +196,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)

-        # Requests that need to start recv.
+        # Requests that need to start recv/send.
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
+        self._reqs_need_send: set[str] = set()

     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -265,6 +268,9 @@ def update_state_after_alloc(self, request: "Request",
             assert num_external_tokens == 0
             # Only trigger 1 KV transfer per request.
             params["do_remote_prefill"] = False
+        elif params is not None and params.get("do_remote_decode"):
+            # Prefill request on remote. It will be read from D upon completion
+            self._reqs_need_send.add(request.request_id)
[Review comment on lines +271 to +273, @njhill (Member), Jun 27, 2025]

I think this should go in request_finished(), and only do it if we return True from that.

We can also set the absolute deadline at this point (the set can be of tuple(req_id, deadline)), and include it in the transfer params that are returned (so the D worker can check it in its get_num_matched_tokens method).

And it's clearer to set a deadline than the finished time, but it should include some buffer to allow for transfer time and slightly misaligned clocks, e.g. a 60 sec deadline for the D side and a 90 sec expiry on the P side.
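A rough sketch of the deadline scheme the reviewer describes, under stated assumptions (hypothetical code: "read_deadline" is an invented param name, the request_finished() body is simplified, and the 60/90 s constants follow the example above):

    import time
    from typing import Any, Optional

    # Shorter deadline for the D side, longer expiry on the P side; the gap
    # absorbs transfer time and slightly misaligned clocks.
    DECODE_READ_DEADLINE_S = 60
    PREFILL_EXPIRY_S = 90

    def request_finished(self, request, block_ids) -> tuple[bool, Optional[dict[str, Any]]]:
        # ... existing logic; assume we decided the blocks must stay allocated
        # until the decode side has read them, i.e. we are returning True.
        now = time.time()  # wall clock, since the deadline crosses machines
        # P side: record an absolute expiry after which blocks are force-freed.
        self._reqs_need_send.add((request.request_id, now + PREFILL_EXPIRY_S))
        # The returned params reach the D worker, which could drop transfers
        # whose deadline has already passed in get_num_new_matched_tokens().
        return True, {"read_deadline": now + DECODE_READ_DEADLINE_S}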


@@ -281,8 +287,10 @@ def build_connector_meta(
                 kv_transfer_params=req.kv_transfer_params,
             )

+        meta.reqs_to_send = copy.copy(self._reqs_need_send)
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
+        self._reqs_need_send.clear()
[Review comment on lines +290 to +293, Member]

Can avoid copying.

Suggested change:

-        meta.reqs_to_send = copy.copy(self._reqs_need_send)
-        # Clear the list once workers start the transfers
-        self._reqs_need_recv.clear()
-        self._reqs_need_send.clear()
+        # Clear the list once workers start the transfers
+        self._reqs_need_recv.clear()
+        # Transfer reqs to send to the metadata
+        meta.reqs_to_send = self._reqs_need_send
+        self._reqs_need_send = set()

         return meta

@@ -394,6 +402,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # In progress transfers.
         # [req_id -> list[handle]]
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
+        # Keep track of the time for requests that are waiting to be sent.
+        self._reqs_to_send: dict[ReqId, float] = {}

         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
@@ -799,6 +809,24 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             "and %s requests done recving", self.tp_rank,
             len(done_sending), len(done_recving))

+        # Handle timeout to avoid stranding blocks on remote.
+        now = time.monotonic()
+        timed_out_requests: list[str] = []
+        for req_id, finish_time in self._reqs_to_send.items():
+            if finish_time < 0:
+                # Request just finished, start timeout.
+                self._reqs_to_send[req_id] = now
+            elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT:
+                # Timeout exceeded, clear the request blocks.
+                timed_out_requests.append(req_id)
+
+        for req_id in timed_out_requests:
+            # Skip communication with other ranks; mark the request done for all.
+            if self.tp_rank == 0:
+                self._done_sending_count[req_id] += self.world_size
+            done_sending.add(req_id)
+            del self._reqs_to_send[req_id]
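For clarity, a compressed trace of the states a tracked request passes through under this diff (annotation added here; timestamps illustrative):

    # 1. start_load_kv(): scheduler metadata lists the request in
    #    reqs_to_send, so the worker records a sentinel:
    #      self._reqs_to_send["req-0"] = -1      # clock not started yet
    # 2. Next get_finished() call: the sentinel is replaced with the current
    #    monotonic time, e.g. self._reqs_to_send["req-0"] = 1042.7
    # 3. Happy path: the consumer's READ notification arrives and
    #    _get_new_notifs() deletes the entry.
    # 4. Timeout path: now - finish_time >= VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
    #    so the request is force-marked done and its blocks are freed.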
[Review comment on lines +815 to +828, Contributor, severity: medium]

The timeout mechanism implemented here relies on time.perf_counter(), which is susceptible to system clock adjustments. If the system clock is adjusted backward, it could cause requests to time out prematurely. Consider using a monotonic clock source that is not affected by system clock changes.

    import time

    # Use time.monotonic() instead of time.perf_counter()
    now = time.monotonic()
    timed_out_requests: list[str] = []
    for req_id, finish_time in self._reqs_to_send.items():
        if finish_time < 0:
            # Request just finished, start timeout.
            self._reqs_to_send[req_id] = now
        elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT:
            # Timeout exceeded, clear the request blocks.
            timed_out_requests.append(req_id)
[Review comment on lines +812 to +828, Member]

I think it's better to keep things simple and omit the TP optimization. I think we'll likely make this logic generic and move it outside of the connector impl anyhow (aggregating the finished events in the TP case).

Dicts are ordered, so we only need to peek the oldest entry.

Suggested change:

-        # Handle timeout to avoid stranding blocks on remote.
-        now = time.monotonic()
-        timed_out_requests: list[str] = []
-        for req_id, finish_time in self._reqs_to_send.items():
-            if finish_time < 0:
-                # Request just finished, start timeout.
-                self._reqs_to_send[req_id] = now
-            elif now - finish_time >= envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT:
-                # Timeout exceeded, clear the request blocks.
-                timed_out_requests.append(req_id)
-        for req_id in timed_out_requests:
-            # Skip communication with other ranks; mark the request done for all.
-            if self.tp_rank == 0:
-                self._done_sending_count[req_id] += self.world_size
-            done_sending.add(req_id)
-            del self._reqs_to_send[req_id]
+        # Handle timeout to avoid stranding blocks on remote.
+        now = time.time()
+        while self._reqs_to_send:
+            req_id, expires = next(iter(self._reqs_to_send.items()))
+            if now < expires:
+                break
+            del self._reqs_to_send[req_id]
+            done_sending.add(req_id)
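The peek-the-oldest trick relies on every request receiving the same timeout, so insertion order equals expiry order. A minimal standalone sketch of the pattern (hypothetical request ids; not the PR's code):

    import time

    reqs_to_send: dict[str, float] = {}  # req_id -> absolute expiry (insertion-ordered)

    def mark_finished(req_id: str, timeout_s: float = 120.0) -> None:
        # Identical timeout for all requests keeps the dict sorted by expiry.
        reqs_to_send[req_id] = time.monotonic() + timeout_s

    def expire_oldest(done_sending: set[str]) -> None:
        now = time.monotonic()
        while reqs_to_send:
            req_id, expires = next(iter(reqs_to_send.items()))
            if now < expires:
                break  # everything behind the head expires even later
            del reqs_to_send[req_id]
            done_sending.add(req_id)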


         if self.world_size == 1:
             return done_sending, done_recving

@@ -830,7 +858,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:

         all_done_sending: set[str] = set()
         for req_id in list(self._done_sending_count.keys()):
-            if self._done_sending_count[req_id] == self.world_size:
+            if self._done_sending_count[req_id] >= self.world_size:
                 del self._done_sending_count[req_id]
                 all_done_sending.add(req_id)

@@ -860,6 +888,7 @@ def _get_new_notifs(self) -> set[str]:
                     tp_ratio):
                 notified_req_ids.add(req_id)
                 del self.consumer_notification_counts_by_req[req_id]
+                del self._reqs_to_send[req_id]
         return notified_req_ids

     def _pop_done_transfers(
@@ -894,7 +923,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         Start loading by triggering non-blocking nixl_xfer.
         We check for these txns to complete in each step().
         """
-        for req_id, meta in metadata.requests.items():
+        for req_id, meta in metadata.reqs_to_recv.items():
             remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
@@ -939,6 +968,11 @@ def request_ready(_f: Future[Any],
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())
+
+        # Track the requests that are waiting to be read and abort on timeout.
+        # Set to -1 so that the timeout does not depend on model latency.
+        for req_id in metadata.reqs_to_send:
+            self._reqs_to_send[req_id] = -1
[Review comment on lines +971 to +974, Member]

I don't think this would be needed per my other comments.

     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
             "Remote agent %s available, calling _read_blocks for req %s",
@@ -1043,8 +1077,7 @@ def _read_blocks(self, local_block_ids: list[int],

         # Use handle to check completion in future step().
         # TODO (NickLucche) surface xfer elapsed time
-        self._recving_transfers[request_id].append(
-            (handle, time.perf_counter()))
+        self._recving_transfers[request_id].append((handle, time.monotonic()))

     def _get_block_descs_ids(self,
                              engine_id: str,
vllm/envs.py (9 additions, 1 deletion)
@@ -138,6 +138,7 @@
     VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
+    VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120


 def get_default_cache_root():
@@ -955,7 +956,14 @@ def get_vllm_port() -> Optional[int]:
     # generations on machines < 100 for compressed-tensors
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
-    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
+    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
+
+    # Time (in seconds) after which the KV cache on the producer side is
+    # automatically cleared if no READ notification is received from the
+    # consumer. This is only applicable when using NixlConnector in a
+    # disaggregated decode-prefill setup.
+    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
+    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
 }

# --8<-- [end:env-vars-definition]
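As a usage sketch (not part of the diff), the knob can be set in the environment before the engine is created, just as the new unit test does via monkeypatch; the 30-second value here is only an example:

    import os

    # Must be set before the engine reads it; envs.py evaluates lazily,
    # so anything prior to constructing the LLM works.
    os.environ["VLLM_NIXL_ABORT_REQUEST_TIMEOUT"] = "30"

    from vllm import LLM
    from vllm.config import KVTransferConfig

    llm = LLM(
        model="Qwen/Qwen3-0.6B",
        kv_transfer_config=KVTransferConfig(kv_connector="NixlConnector",
                                            kv_role="kv_both"),
    )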