Merged
4 changes: 3 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -190,7 +190,9 @@ def get_finished(
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
finished generating tokens on the worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.

Returns:
ids of requests that have finished asynchronous transfer
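With the rank-0 aggregation gone, the contract is that each worker-side connector reports only transfers finished on its own rank, and the cross-worker bookkeeping moves to the MultiprocExecutor. A minimal sketch of a connector honoring this contract (class and attribute names are illustrative, not vLLM's actual implementation):

from typing import Optional

class SketchConnectorWorker:
    # Illustrative only: tracks purely rank-local transfer state.

    def __init__(self) -> None:
        self._done_sending: set[str] = set()
        self._done_recving: set[str] = set()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        # Report transfers finished on THIS rank only; the scheduler
        # process (via the MultiprocExecutor) counts one report per rank.
        sending, self._done_sending = self._done_sending, set()
        recving, self._done_recving = self._done_recving, set()
        return sending or None, recving or None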
65 changes: 4 additions & 61 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -408,14 +408,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}

# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
# [req_id -> count]
self._done_recving_count: defaultdict[ReqId,
int] = defaultdict(lambda: 0)
self._done_sending_count: defaultdict[ReqId,
int] = defaultdict(lambda: 0)

# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
# Background thread for initializing new NIXL handshakes.
@@ -821,15 +813,9 @@ def add_remote_agent(self,

def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving.

In TP>1 setup, each rank exchanges KVs with its counterpart
ranks independently. get_finished() runs in a worker creates
the done_sending and done_recving sets that are sent to the
scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
are done before adding to finished, Ranks 1 to N-1 communicate
to Rank 0 once their transaction is done + Rank 0 returns
finished sets to Scheduler only once all ranks are done.
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
@@ -849,50 +835,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
del self._reqs_to_send[req_id]
done_sending.add(req_id)

if self.world_size == 1:
return done_sending, done_recving

# Rank 0: get finished from all other ranks.
if self.tp_rank == 0:
for req_id in done_sending:
self._done_sending_count[req_id] += 1
for req_id in done_recving:
self._done_recving_count[req_id] += 1

# Keep track of how many other ranks have finished.
other_ranks_finished_ids: list[str] = []
for i in range(1, self.world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
if (req_id in self._done_recving_count
or req_id in self._recving_transfers):
self._done_recving_count[req_id] += 1
else:
self._done_sending_count[req_id] += 1

# Return ids that finished on all ranks to the scheduler.
all_done_recving: set[str] = set()
for req_id in list(self._done_recving_count.keys()):
if self._done_recving_count[req_id] == self.world_size:
del self._done_recving_count[req_id]
all_done_recving.add(req_id)

all_done_sending: set[str] = set()
for req_id in list(self._done_sending_count.keys()):
if self._done_sending_count[req_id] >= self.world_size:
del self._done_sending_count[req_id]
all_done_sending.add(req_id)

return all_done_sending, all_done_recving

# Ranks 1 to N-1: send finished ids to Rank 0.
else:
finished_req_ids = list(done_recving.union(done_sending))
self.tp_group.send_object(finished_req_ids, dst=0)

# Unused as only Rank 0 results are sent to scheduler.
return done_sending, done_recving
return done_sending, done_recving

def _get_new_notifs(self) -> set[str]:
"""
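The timeout sweep at the top of the trimmed get_finished (partly behind the fold above) walks self._reqs_to_send, whose values are expiration timestamps, and folds any expired request into done_sending. A self-contained sketch of that pattern (variable names hypothetical):

import time

reqs_to_send = {"req-1": time.monotonic() - 1.0,   # deadline already passed
                "req-2": time.monotonic() + 60.0}  # still within its window
done_sending: set[str] = set()

now = time.monotonic()
for req_id, expires in list(reqs_to_send.items()):
    # iterate over a copy, since entries are deleted during the sweep
    if now >= expires:
        del reqs_to_send[req_id]
        done_sending.add(req_id)

assert done_sending == {"req-1"}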
110 changes: 105 additions & 5 deletions vllm/v1/executor/multiproc_executor.py
@@ -9,7 +9,8 @@
import time
import traceback
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from collections import defaultdict
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
@@ -111,10 +112,19 @@ def _init_executor(self) -> None:
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io")

self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None

        # Complete transfer tracker. Used to track finished requests
        # [req_id -> n_finished_workers]
        self._recv_remaining_count = defaultdict[str, int](lambda: self.world_size)
        self._send_remaining_count = defaultdict[str, int](lambda: self.world_size)

def start_worker_monitor(self):
workers = self.workers
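Each counter above starts at world_size the first time a request id is seen and is decremented once per worker that reports it; reaching zero means every rank has finished that request. A standalone sketch of the countdown pattern (toy values, independent of vLLM):

from collections import defaultdict

world_size = 4
# defaultdict[str, int] is a parameterized generic alias (Python 3.9+);
# calling it builds a defaultdict whose factory seeds new ids at world_size.
remaining = defaultdict[str, int](lambda: world_size)
finished: set[str] = set()

# four workers report in; only "req-a" is reported by all of them
for worker_report in (["req-a"], ["req-a"], ["req-a", "req-b"], ["req-a"]):
    for req_id in worker_report:
        remaining[req_id] -= 1
        if remaining[req_id] == 0:
            del remaining[req_id]  # every rank has reported it
            finished.add(req_id)

assert finished == {"req-a"}
assert remaining["req-b"] == 3  # still waiting on three ranks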
@@ -155,13 +165,29 @@ def execute_model(
self,
scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
(output, ) = self.collective_rpc(
non_block = self.max_concurrent_batches > 1

if not self.has_connector:
# get output only from a single worker (output_rank)
(output, ) = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
return output

# get output from all workers
outputs = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
unique_reply_rank=self.output_rank,
non_block=self.max_concurrent_batches > 1,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
return output

# aggregate all workers output to a single output
if non_block:
return self._async_aggregate_workers_output(outputs)
return self._aggregate_workers_output(outputs)

def collective_rpc(self,
method: Union[str, Callable],
@@ -220,6 +246,80 @@ def get_response(w: WorkerProcHandle,
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e

    def _aggregate_workers_output(
            self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
        # aggregate finished_sending, finished_recving from all workers

        finished_sending = set[str]()
        finished_recving = set[str]()
        for output in outputs:
            # update finished_sending
            for req_id in output.finished_sending or []:
                new_count = self._send_remaining_count[req_id] - 1
                if new_count == 0:
                    # got response from all workers, report back to scheduler
                    finished_sending.add(req_id)
                    del self._send_remaining_count[req_id]
                else:
                    self._send_remaining_count[req_id] = new_count

            # update finished_recving
            for req_id in output.finished_recving or []:
                new_count = self._recv_remaining_count[req_id] - 1
                if new_count == 0:
                    # got response from all workers, report back to scheduler
                    finished_recving.add(req_id)
                    del self._recv_remaining_count[req_id]
                else:
                    self._recv_remaining_count[req_id] = new_count
Review comment from a Contributor on lines +253 to +274, suggesting the two countdown loops be deduplicated into a helper. Suggested change:
        def update_finished_set(req_ids: Optional[list[str]],
                                remaining_count_dict: dict[str, int],
                                finished_set: set[str]) -> None:
            for req_id in req_ids or []:
                new_count = remaining_count_dict[req_id] - 1
                if new_count == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]
                else:
                    remaining_count_dict[req_id] = new_count

        finished_sending = set[str]()
        finished_recving = set[str]()
        for output in outputs:
            update_finished_set(output.finished_sending,
                                self._send_remaining_count, finished_sending)
            update_finished_set(output.finished_recving,
                                self._recv_remaining_count, finished_recving)


        # select output of the worker specified by output_rank
        output = outputs[self.output_rank]

        # set the aggregated finished_sending / finished_recving
        if finished_sending:
            output.finished_sending = finished_sending
        if finished_recving:
            output.finished_recving = finished_recving

        return output

    def _async_aggregate_workers_output(
            self, output_futures: list[Future[ModelRunnerOutput]]
    ) -> Future[ModelRunnerOutput]:
        """Takes a list of futures and returns a single future which resolves
        to the aggregated output once all workers' futures complete."""
        result_future: Future[ModelRunnerOutput] = Future()

        outputs: list[Optional[ModelRunnerOutput]] = [None] * len(output_futures)

        def make_callback(idx):

            def callback(fut):
                if result_future.done():
                    return

                try:
                    outputs[idx] = fut.result()
                except CancelledError:
                    result_future.cancel()
                except Exception as e:
                    result_future.set_exception(e)

                # this check assumes io_thread_pool uses a single thread
                if all(outputs):
                    result_future.set_result(
                        self._aggregate_workers_output(
                            cast(list[ModelRunnerOutput], outputs)))

            return callback

        for i, output_future in enumerate(output_futures):
            output_future.add_done_callback(make_callback(i))

        return result_future
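The fan-in shape can be exercised in isolation. A toy sketch (plain integers instead of ModelRunnerOutput, summation as a stand-in aggregator) of resolving one future from many via per-future callbacks:

from concurrent.futures import Future, ThreadPoolExecutor

def aggregate(futures: list[Future]) -> Future:
    result: Future = Future()
    slots: list = [None] * len(futures)

    def make_callback(idx: int):

        def callback(fut: Future) -> None:
            if result.done():
                return
            try:
                slots[idx] = fut.result()
            except Exception as exc:
                result.set_exception(exc)
                return
            # as in the executor above, this relies on callbacks running on
            # a single thread; "is not None" also sidesteps the falsy-value
            # pitfall of a bare all(slots)
            if all(s is not None for s in slots):
                result.set_result(sum(slots))

        return callback

    for i, fut in enumerate(futures):
        fut.add_done_callback(make_callback(i))
    return result

with ThreadPoolExecutor(max_workers=1) as pool:
    futs = [pool.submit(lambda v=v: v * v) for v in (1, 2, 3)]
    print(aggregate(futs).result())  # 1 + 4 + 9 = 14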

@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
"""Ensure that all worker processes are terminated. Assumes workers have
Expand Down
46 changes: 6 additions & 40 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import gc
import time
import weakref
@@ -1234,8 +1233,6 @@ def _pool(
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
finished_sending: Optional[set[str]],
finished_recving: Optional[set[str]],
) -> ModelRunnerOutput:
assert self.input_batch.num_reqs ==\
len(self.input_batch.pooling_params), \
@@ -1270,8 +1267,6 @@ def _pool(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
finished_sending=finished_sending,
finished_recving=finished_recving,
)

@torch.inference_mode()
@@ -1282,11 +1277,12 @@ def execute_model(
) -> Union[ModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if has_kv_transfer_group():
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)

Review comment (Contributor): We're missing a call to clear_connector_metadata in this case (also before this change).

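A sketch of what addressing that comment could look like on the no-work path. clear_connector_metadata is taken from the comment itself; its exact placement here is an assumption, not the PR's code:

if has_kv_transfer_group():
    with set_forward_context(None, self.vllm_config):
        self.maybe_setup_kv_connector(scheduler_output)
        # hypothetical fix per the review comment: unbind the metadata
        # that maybe_setup_kv_connector attached before returning
        get_kv_transfer_group().clear_connector_metadata()
return EMPTY_MODEL_RUNNER_OUTPUT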
return self.kv_connector_no_forward(scheduler_output)
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT

# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
@@ -1379,8 +1375,6 @@ def execute_model(
)

self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))

if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
@@ -1406,8 +1400,7 @@ def execute_model(
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving)
num_scheduled_tokens_np)

sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1560,8 +1553,6 @@ def execute_model(
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits,
)

@@ -1686,22 +1677,6 @@ def propose_draft_token_ids(
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids

def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))

if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT

output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output

@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
@@ -1723,15 +1698,6 @@ def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()

@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None

def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],