
[P/D][Bugfix]: Fix the metadata corruption issue in Nixl when TP > 1. #19341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
20 changes: 13 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -664,7 +664,20 @@ def add_remote_agent(self,
# TODO re-evaluate refreshing for scaling/recovery
if remote_tp_rank in self._remote_agents.get(engine_id, ()):
    return
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
assert self._tp_size[self.engine_id] % nixl_agent_meta.tp_size == 0, (
    "Local TP size must be divisible by remote TP size.")
tp_ratio = self._tp_size[self.engine_id] // nixl_agent_meta.tp_size
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"

if remote_tp_rank != self.tp_rank // tp_ratio:
Contributor:
Rather than an early exit here, which is equivalent to skipping the descs at L727, I'd prefer fixing the num_blocks issue once and for all and still executing the other asserts (those are extra correctness checks that won't hurt to have).

Contributor Author:
Got it, I'll look into the solution you proposed.

Just to add/clarify:

To be precise, this early exit only skips the first handshake execution; the second handshake execution is not skipped.

# Handshake with remote agent-rank0 first to get the tp_size of remote
path = make_zmq_path("tcp", host, port)
logger.debug("Querying master rank metadata on path: %s", path)
metadata = handshake(path, 0)

The first handshake execution is solely for obtaining the remote tp_size, so it can be safely skipped.

path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s",
path, p_remote_rank)
_ = handshake(path, p_remote_rank)

The second handshake execution, however, is not skipped; all checks are performed as expected.
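
For illustration, here is a minimal standalone sketch of the rank mapping this check relies on (the helper name is hypothetical and not vLLM API; only the arithmetic mirrors the diff): a decode rank d pulls from prefill rank d // tp_ratio, which is exactly the `remote_tp_rank != self.tp_rank // tp_ratio` condition above.

def prefill_rank_pulled_from(decode_rank: int, decode_tp: int,
                             prefill_tp: int) -> int:
    # Hypothetical helper, not vLLM API: which prefill (P) TP rank a
    # decode (D) TP rank reads KV blocks from.
    assert decode_tp % prefill_tp == 0, (
        "Local TP size must be divisible by remote TP size.")
    tp_ratio = decode_tp // prefill_tp  # D workers per P worker
    return decode_rank // tp_ratio

# With D at TP=4 and P at TP=2, tp_ratio is 2: decode ranks 0 and 1
# pull from prefill rank 0, decode ranks 2 and 3 from prefill rank 1.
for d_rank in range(4):
    print(d_rank, "->", prefill_rank_pulled_from(d_rank, 4, 2))

Under this mapping, any agent can answer the rank-0 handshake (whose only purpose is learning the remote tp_size), while the per-rank handshake at port + p_remote_rank runs the full set of checks.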

Contributor:
Yeah, absolutely; this was also the original intent: the handshake with rank 0 is just to get tp_size.
But skipping a bunch of extra asserts on top of that won't solve the core issue.

    # Only register remote agents that this local rank pulls from.
    logger.debug(
        "Skipping registration of remote agent %s with rank %s "
        "as it is not the one this local rank %s pulls from.",
        engine_id, remote_tp_rank, self.tp_rank)
    return
if engine_id in self._tp_size:
    assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
else:
@@ -676,13 +689,6 @@ def add_remote_agent(self,
self._remote_agents[engine_id][
    remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
        nixl_agent_meta.agent_metadata)

# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
"Local TP size must be divisible by remote TP size.")
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
if self.use_mla:
    # With MLA the only difference is in the number of blocks.
    remote_block_size = nixl_agent_meta.block_len // (