
[Bugfix][Nixl] Fix full prefix cache hit bug #18632

Merged
merged 15 commits on Jun 5, 2025
76 changes: 47 additions & 29 deletions tests/v1/kv_connector/unit/test_multi_connector.py
@@ -12,6 +12,7 @@
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

@@ -32,7 +33,7 @@ def __init__(self, config: VllmConfig, role):
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = tempfile.gettempdir(
) + f"/connector_{self.name}_events.log"
) + f"/connector_{self.name}-{self.role.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass
@@ -52,10 +53,19 @@ def __getattribute__(self, name):

def wrapper(*args, **kwargs):
self.call_record[name] += 1

# Include args that we're interested in
to_log = [name]
for arg in args:
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(f"num_blocks={len(arg.blocks)}")

# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(name + "\n")
f.write(' '.join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
@@ -162,15 +172,23 @@ def test_multi_shared_storage_connector_consistency():
f"{storage_1_path} and {storage_2_path}")

events = get_connector_events()
# get_num_new_matched_tokens will be called on each connector in turn.
# neither of them have hits so update_state_after_alloc won't be called.
assert events["storage1"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
# get_num_new_matched_tokens and update_state_after_alloc will be called
# on each connector in turn.
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
]
assert events["storage1-WORKER"][:5] == [
'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
'wait_for_layer_load', 'save_kv_layer'
]
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
]
assert events["storage2"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
assert events["storage2-WORKER"][:5] == [
'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
'wait_for_layer_load', 'save_kv_layer'
]

# Reset prefix cache or else we'll just get the tokens back from there.
@@ -182,16 +200,16 @@ def test_multi_shared_storage_connector_consistency():

events = get_connector_events()
# get_num_new_matched_tokens will return new tokens from the first
# connector so update_state_after_alloc will be called once blocks
# are allocated for the first connector.
# get_num_new_matched_tokens *won't* be called on the second connector
# in this case.
assert events["storage1"][:4] == [
'get_num_new_matched_tokens', 'update_state_after_alloc',
'build_connector_meta', 'bind_connector_metadata'
# connector so update_state_after_alloc will be called with allocated blocks
# on that one but with zero blocks for others (first nonzero match is
# chosen).
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
]
assert events["storage2"][:2] == [
'build_connector_meta', 'bind_connector_metadata'
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
]

# Delete storage1 connector state
@@ -205,17 +223,17 @@ def test_multi_shared_storage_connector_consistency():
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)

events = get_connector_events()
# get_num_new_matched_tokens will be called for the first connector but it
# won't have a hit so update_state_after_alloc won't be called.
# get_num_new_matched_tokens will also be called on the second connector,
# but it should have a hit so update_state_after_alloc will be called.
assert events["storage1"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
# get_num_new_matched_tokens will be called for both connectors; it will
# return 0 for the first, but the second connector should have a hit,
# so update_state_after_alloc will only be called with allocated
# blocks for the second connector.
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
]
assert events["storage2"][:4] == [
'get_num_new_matched_tokens', 'update_state_after_alloc',
'build_connector_meta', 'bind_connector_metadata'
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
]

# Clean up
41 changes: 26 additions & 15 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -12,12 +12,12 @@
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

logger = init_logger(__name__)
@@ -51,8 +51,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self._connectors.append(
KVConnectorFactory.create_connector_v1(temp_config, role))

# A mapping from request id to the connector that is assigned to it.
self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
self._requests_to_connector: dict[str, int] = {}

# Keeps track of *additional* remaining async saves (beyond 1) to be
# finished per request. Not needed for async loads since we only allow
@@ -136,25 +137,31 @@ def get_num_new_matched_tokens(
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
for c in self._connectors:
to_return = (0, False)
for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens)
# The first connector that has new matched tokens will be assigned
# to this request.
if toks > 0:
self._requests_to_connector[request.request_id] = c
return toks, load_async
return 0, False
if to_return[0] == 0 and toks > 0:
self._requests_to_connector[request.request_id] = i
to_return = (toks, load_async)
return to_return

def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
# If the request is not assigned to any connector, we do nothing.
if request.request_id not in self._requests_to_connector:
return
# We assume that the request is assigned to only one connector.
c = self._requests_to_connector.pop(request.request_id)
c.update_state_after_alloc(request, blocks, num_external_tokens)
chosen_connector = self._requests_to_connector.get(
request.request_id, -1)
for i, c in enumerate(self._connectors):
if i == chosen_connector:
# Forward call to the chosen connector (if any).
c.update_state_after_alloc(request, blocks,
num_external_tokens)
else:
# Call with empty blocks for other connectors.
c.update_state_after_alloc(request,
KVCacheBlocks.create_empty(), 0)

def build_connector_meta(
self,
@@ -170,7 +177,7 @@ def build_connector_meta(
def request_finished(
self,
request: "Request",
blocks: "KVCacheBlocks",
blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
async_saves = 0
kv_txfer_params = None
Expand All @@ -187,4 +194,8 @@ def request_finished(
kv_txfer_params = txfer_params
if async_saves > 1:
self._extra_async_saves[request.request_id] = async_saves - 1

# Clean up other state for this request.
self._requests_to_connector.pop(request.request_id, None)

return async_saves > 0, kv_txfer_params
53 changes: 20 additions & 33 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -221,15 +221,6 @@ def get_num_new_matched_tokens(
if count > 0:
return count, True

# NOTE: if count is 0 here, we have less than block_size
# tokens to pull after subtracting the local prefix cache hit.
# The remote only sends fully computed blocks, so there is
# nothing to transfer but we still need to notify the
# prefill worker so that the remote blocks are freed.
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
self._reqs_need_recv[request.request_id] = (request, [])

Collaborator Author: this is now handled in update_state_after_alloc

# No remote prefill for this request.
return 0, False

@@ -247,9 +238,14 @@ def update_state_after_alloc(self, request: "Request",
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (blocks.get_unhashed_block_ids()
if num_external_tokens > 0 else [])
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
request, local_block_ids)
Comment on lines +241 to +248 (Member):

@robertgshaw2-redhat I'm still not sure that this part or the change to always call update_state_after_alloc is needed. I'd already added logic for this case in get_num_new_matched_tokens above:

# NOTE: if count is 0 here, we have less than block_size
# tokens to pull after subtracting the local prefix cache hit.
# The remote only sends fully computed blocks, so there is
# nothing to transfer but we still need to notify the
# prefill worker so that the remote blocks are freed.
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
self._reqs_need_recv[request.request_id] = (request, [])

I can see that the other two fixes below in build_connector_meta and _read_blocks are of course needed though.

If you think it's better to have this logic in this method then we can remove it from the other one. But again I feel it's logically clearer to not call update_state_after_alloc if 0 was returned from get_num_new_matched_tokens.

Collaborator Author (@robertgshaw2-redhat), May 24, 2025:

I think that get_num_new_matched_tokens should be a pure function. Adding a side effect to it is surprising given the name of the method, and it means behavior differs depending on whether or not the request can be scheduled. This is actually causing a bug right now.

  • If allocate_slots returns None, the request will remain in the waiting queue. This causes us to add the request to reqs_need_recv more than once, and as a result we call read_blocks twice, which double-frees blocks on the P worker side. The same happens if the request is preempted (it gets re-added to the waiting queue). This is because we do not properly update the request to have do_remote_prefill=False when it is added to reqs_need_recv from get_num_new_matched_tokens.

This is all evidence that putting a side effect into this function is not a good idea. update_state_after_alloc is where we should handle everything related to reqs_need_recv, so that all of the logic lives in a single place.
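
As a rough illustration of the failure mode described above, here is a minimal sketch; the names ToyConnector, schedule_step, and allocate_slots are illustrative stand-ins, not the actual vLLM scheduler or connector API. If the connector recorded the request as a side effect of get_num_new_matched_tokens, a request whose allocation fails would be retried and recorded twice; recording it only in update_state_after_alloc ties the bookkeeping to a successful allocation.

# Editorial sketch (not part of the PR): toy names only.
class ToyConnector:
    def __init__(self) -> None:
        self.reqs_need_recv: dict[str, list[int]] = {}

    def get_num_new_matched_tokens(self, req_id: str) -> int:
        # Pure query: report how many tokens the remote side could provide,
        # without mutating any connector state.
        return 128

    def update_state_after_alloc(self, req_id: str,
                                 block_ids: list[int]) -> None:
        # Record the request only once blocks have actually been allocated,
        # so a failed or retried allocation never leaves duplicate state.
        self.reqs_need_recv[req_id] = block_ids

def schedule_step(connector: ToyConnector, req_id: str, allocate_slots) -> bool:
    num_external = connector.get_num_new_matched_tokens(req_id)
    blocks = allocate_slots(num_external)  # may return None under memory pressure
    if blocks is None:
        # Request stays in the waiting queue and is retried next step; since
        # the query above had no side effect, nothing stale was recorded and
        # no double free can follow.
        return False
    connector.update_state_after_alloc(req_id, blocks)
    return True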

Collaborator Author:

I removed those lines from get_num_new_matched_tokens

Member:

@robertgshaw2-redhat that makes sense, I agree about the pure-function point. I had also noticed that this could result in a double free on the P worker side when the request can't be scheduled, which isn't ideal (though I think it would probably be harmless).

But from the point of view of a generic connector interface, it still feels a bit odd to call update_state_after_alloc when the connector isn't offering any tokens. We should very clearly document the semantics and expectations of the interface.

A related quirk is that in the async load case, I think update_state_after_alloc will currently be called twice for a request (a second time once the request moves out of WAITING_FOR_REMOTE_KVS).

else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
@@ -268,15 +264,6 @@ def build_connector_meta(
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
# an async read on the worker side.
if not block_ids:
logger.debug(
"Skipping adding request %s to NixlConnectorMetadata, "
"as there are no remote blocks to pull", req_id)
continue

meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
@@ -660,26 +647,26 @@ def add_remote_agent(self,

# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
"Local TP size must be divisible by remote TP size."
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
"Local TP size must be divisible by remote TP size.")
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
assert tp_ratio > 0, "Decode TP cannot be smaller than"
" prefill TP"
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
if self.use_mla:
# With MLA the only difference is in the number of blocks.
remote_block_size = nixl_agent_meta.block_len / (
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes)
assert self.block_len == nixl_agent_meta.block_len
else:
remote_block_size = nixl_agent_meta.block_len / (
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)

assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \
"Remote P worker KV layer cache must be of shape [2, N, \
local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)

assert self.block_size == remote_block_size, "Remote P worker with \
different block size is not supported"
assert self.block_size == remote_block_size, "Remote P worker with "
"different block size is not supported"

assert self.num_blocks >= nixl_agent_meta.num_blocks

@@ -712,9 +699,9 @@ def add_remote_agent(self,
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and " \
"local rank %s",
len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
self.tp_rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
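To summarize the intent of the nixl_connector.py changes above, here is a hedged sketch of the full-prefix-cache-hit path; the free-standing functions and the reqs_need_recv, send_notif, and start_transfer names are simplifications, not the real NixlConnector state or NIXL calls. Even when the decode worker needs no data (num_external_tokens == 0), the request is still recorded with an empty local block list so that the worker-side read path notifies the prefill worker to free its blocks.

# Editorial sketch (not the actual NixlConnector implementation).
def update_state_after_alloc(reqs_need_recv: dict, request_id: str,
                             params: dict, unhashed_block_ids: list[int],
                             num_external_tokens: int) -> None:
    if not params.get("remote_block_ids"):
        return
    # Full prefix cache hit on the D worker: nothing to pull, but the request
    # is still recorded (with no local blocks) so the P-side memory gets freed.
    local_block_ids = unhashed_block_ids if num_external_tokens > 0 else []
    reqs_need_recv[request_id] = (params, local_block_ids)

def read_blocks(reqs_need_recv: dict, send_notif, start_transfer) -> None:
    for request_id, (params, local_block_ids) in reqs_need_recv.items():
        if not local_block_ids:
            # No transfer needed; just tell the prefill worker it can free
            # the remote blocks held for this request.
            send_notif(params["remote_engine_id"], request_id)
            continue
        start_transfer(request_id, local_block_ids, params["remote_block_ids"])
    reqs_need_recv.clear()
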
8 changes: 4 additions & 4 deletions vllm/v1/core/sched/scheduler.py
@@ -424,11 +424,11 @@ def schedule(self) -> SchedulerOutput:
# The request cannot be scheduled.
break

# KVConnector: update internal state after allocation.
# KVTransfer: the connector uses this info to determine
# if a load is needed. Note that
# This information is used to determine if a load is
# needed for this request.
if num_external_computed_tokens:
assert self.connector is not None
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
@@ -841,7 +841,7 @@ def update_from_output(
}

finished_req_ids = self.finished_req_ids_dict
if finished_req_ids is not None:
if finished_req_ids:
# Include ids of requests that finished since last outputs
# were sent.
for client_index, finished_set in finished_req_ids.items():