-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[Bugfix][Nixl] Fix full prefix cache hit bug #18632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b15f974
95408aa
61a2900
6bde0f1
2c3cb80
dd35648
f6ed8c4
c5546c3
fb844a5
9e435e0
354d775
45bd917
0c30192
cac5027
3eaea72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -221,15 +221,6 @@ def get_num_new_matched_tokens( | |||||||||||||||||
if count > 0: | ||||||||||||||||||
return count, True | ||||||||||||||||||
|
||||||||||||||||||
# NOTE: if count is 0 here, we have less than block_size | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is now handled in |
||||||||||||||||||
# tokens to pull after subtracting the local prefix cache hit. | ||||||||||||||||||
# The remote only sends fully computed blocks, so there is | ||||||||||||||||||
# nothing to transfer but we still need to notify the | ||||||||||||||||||
# prefill worker so that the remote blocks are freed. | ||||||||||||||||||
if all(p in params for p in ("remote_engine_id", "remote_host", | ||||||||||||||||||
"remote_port")): | ||||||||||||||||||
self._reqs_need_recv[request.request_id] = (request, []) | ||||||||||||||||||
|
||||||||||||||||||
# No remote prefill for this request. | ||||||||||||||||||
return 0, False | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -247,9 +238,14 @@ def update_state_after_alloc(self, request: "Request", | |||||||||||||||||
if params.get("remote_block_ids"): | ||||||||||||||||||
if all(p in params for p in ("remote_engine_id", "remote_host", | ||||||||||||||||||
"remote_port")): | ||||||||||||||||||
# If remote_blocks and num_external_tokens = 0, we have | ||||||||||||||||||
# a full prefix cache hit on the D worker. We need to call | ||||||||||||||||||
# send_notif in _read_blocks to free the memory on the P. | ||||||||||||||||||
local_block_ids = (blocks.get_unhashed_block_ids() | ||||||||||||||||||
if num_external_tokens > 0 else []) | ||||||||||||||||||
# Get unhashed blocks to pull from remote. | ||||||||||||||||||
self._reqs_need_recv[request.request_id] = ( | ||||||||||||||||||
request, blocks.get_unhashed_block_ids()) | ||||||||||||||||||
request, local_block_ids) | ||||||||||||||||||
Comment on lines
+241
to
+248
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @robertgshaw2-redhat I'm still not sure that this part or the change to always call vllm/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Lines 215 to 222 in f203673
I can see that the other two fixes below in If you think it's better to have this logic in this method then we can remove it from the other one. But again I feel it's logically clearer to not call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that
This is all just evidence that putting a side effect into this function is not a good idea. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed those lines from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @robertgshaw2-redhat that makes sense, I agree about the pure function thing. I did also notice the fact that this could result in a double free on the P worker side in the case that it can't be scheduled, which isn't ideal (though I think would probably be harmless). But to me, thinking from the pov of a generic connector interface, it still feels a bit odd given the connector isn't offering any tokens. I guess we should very clearly document the semantics and expectations for the interface. A related quirk is that in the async load case, I think currently
njhill marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||
else: | ||||||||||||||||||
logger.warning( | ||||||||||||||||||
"Got invalid KVTransferParams: %s. This " | ||||||||||||||||||
|
@@ -268,15 +264,6 @@ def build_connector_meta( | |||||||||||||||||
# Loop through scheduled reqs and convert to ReqMeta. | ||||||||||||||||||
for req_id, (req, block_ids) in self._reqs_need_recv.items(): | ||||||||||||||||||
assert req.kv_transfer_params is not None | ||||||||||||||||||
# For the case where there are no remote blocks to pull | ||||||||||||||||||
# (block_ids is empty), we don't need to schedule | ||||||||||||||||||
# an async read on the worker side. | ||||||||||||||||||
if not block_ids: | ||||||||||||||||||
logger.debug( | ||||||||||||||||||
"Skipping adding request %s to NixlConnectorMetadata, " | ||||||||||||||||||
"as there are no remote blocks to pull", req_id) | ||||||||||||||||||
continue | ||||||||||||||||||
|
||||||||||||||||||
meta.add_new_req( | ||||||||||||||||||
request_id=req_id, | ||||||||||||||||||
local_block_ids=block_ids, | ||||||||||||||||||
|
@@ -660,26 +647,26 @@ def add_remote_agent(self, | |||||||||||||||||
|
||||||||||||||||||
# Number of D TP workers reading from a single P TP worker. This is | ||||||||||||||||||
# 1 when P and D `--tensor-parallel-size` match. | ||||||||||||||||||
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \ | ||||||||||||||||||
"Local TP size must be divisible by remote TP size." | ||||||||||||||||||
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, ( | ||||||||||||||||||
"Local TP size must be divisible by remote TP size.") | ||||||||||||||||||
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] | ||||||||||||||||||
assert tp_ratio > 0, "Decode TP cannot be smaller than" | ||||||||||||||||||
" prefill TP" | ||||||||||||||||||
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" | ||||||||||||||||||
if self.use_mla: | ||||||||||||||||||
# With MLA the only difference is in the number of blocks. | ||||||||||||||||||
remote_block_size = nixl_agent_meta.block_len / ( | ||||||||||||||||||
remote_block_size = nixl_agent_meta.block_len // ( | ||||||||||||||||||
self.slot_size_bytes) | ||||||||||||||||||
assert self.block_len == nixl_agent_meta.block_len | ||||||||||||||||||
else: | ||||||||||||||||||
remote_block_size = nixl_agent_meta.block_len / ( | ||||||||||||||||||
remote_block_size = nixl_agent_meta.block_len // ( | ||||||||||||||||||
self.slot_size_bytes * tp_ratio) | ||||||||||||||||||
|
||||||||||||||||||
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \ | ||||||||||||||||||
"Remote P worker KV layer cache must be of shape [2, N, \ | ||||||||||||||||||
local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." | ||||||||||||||||||
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( | ||||||||||||||||||
"Remote P worker KV layer cache must be of shape [2, N, " | ||||||||||||||||||
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
assert self.block_size == remote_block_size, "Remote P worker with \ | ||||||||||||||||||
different block size is not supported" | ||||||||||||||||||
assert self.block_size == remote_block_size, "Remote P worker with " | ||||||||||||||||||
"different block size is not supported" | ||||||||||||||||||
|
||||||||||||||||||
assert self.num_blocks >= nixl_agent_meta.num_blocks | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -712,9 +699,9 @@ def add_remote_agent(self, | |||||||||||||||||
# (addr, len, device id) | ||||||||||||||||||
blocks_data.append((addr, self.block_len, remote_tp_rank)) | ||||||||||||||||||
logger.debug( | ||||||||||||||||||
"Created %s blocks for dst engine %s with remote rank %s and " \ | ||||||||||||||||||
"local rank %s", | ||||||||||||||||||
len(blocks_data), engine_id, remote_tp_rank, self.tp_rank) | ||||||||||||||||||
"Created %s blocks for dst engine %s with remote rank %s and " | ||||||||||||||||||
"local rank %s", len(blocks_data), engine_id, remote_tp_rank, | ||||||||||||||||||
self.tp_rank) | ||||||||||||||||||
|
||||||||||||||||||
# Register with NIXL. | ||||||||||||||||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") | ||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.