-
-
Notifications
You must be signed in to change notification settings - Fork 10.9k
[KVConnector] Aggregate finished requests on the scheduler #19555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
njhill
merged 1 commit into
vllm-project:main
from
orozery:connector-metadata-worker-output
Jul 10, 2025
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,7 +9,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import traceback | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import weakref | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from concurrent.futures import Future, ThreadPoolExecutor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from concurrent.futures import CancelledError, Future, ThreadPoolExecutor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from enum import Enum, auto | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -111,10 +112,19 @@ def _init_executor(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.max_concurrent_batches > 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Note: must use only 1 IO thread to keep dequeue sequence | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # from the response queue | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # _async_aggregate_workers_output also assumes a single IO thread | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.io_thread_pool = ThreadPoolExecutor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_workers=1, thread_name_prefix="mp_exec_io") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.output_rank = self._get_output_rank() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.has_connector = self.vllm_config.kv_transfer_config is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Complete transfer tracker. Used by to track finished requests | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # [req_id -> n_finished_workers] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._recv_remaining_count = defaultdict[str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int](lambda: self.world_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._send_remaining_count = defaultdict[str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int](lambda: self.world_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def start_worker_monitor(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workers = self.workers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -155,13 +165,29 @@ def execute_model( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scheduler_output, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (output, ) = self.collective_rpc( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| non_block = self.max_concurrent_batches > 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.has_connector: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # get output only from a single worker (output_rank) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (output, ) = self.collective_rpc( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "execute_model", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args=(scheduler_output, ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| unique_reply_rank=self.output_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| non_block=non_block, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # get output from all workers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = self.collective_rpc( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "execute_model", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args=(scheduler_output, ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| unique_reply_rank=self.output_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| non_block=self.max_concurrent_batches > 1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| non_block=non_block, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # aggregate all workers output to a single output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if non_block: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self._async_aggregate_workers_output(outputs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self._aggregate_workers_output(outputs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def collective_rpc(self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| method: Union[str, Callable], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -220,6 +246,80 @@ def get_response(w: WorkerProcHandle, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except TimeoutError as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise TimeoutError(f"RPC call to {method} timed out.") from e | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _aggregate_workers_output( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # aggregate finished_sending, finished_recving from all workers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| finished_sending = set[str]() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| finished_recving = set[str]() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for output in outputs: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # update finished_sending | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for req_id in output.finished_sending or []: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| new_count = self._send_remaining_count[req_id] - 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if new_count == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # got response from all workers, report back to scheduler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| finished_sending.add(req_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| del self._send_remaining_count[req_id] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._send_remaining_count[req_id] = new_count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # update finished_recving | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for req_id in output.finished_recving or []: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| new_count = self._recv_remaining_count[req_id] - 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if new_count == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # got response from all workers, report back to scheduler | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| finished_recving.add(req_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| del self._recv_remaining_count[req_id] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._recv_remaining_count[req_id] = new_count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+253
to
+274
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # select output of the worker specified by output_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output = outputs[self.output_rank] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set the aggregated finished_sending / finished_recving | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if finished_sending: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output.finished_sending = finished_sending | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if finished_recving: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output.finished_recving = finished_recving | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _async_aggregate_workers_output( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, output_futures: list[Future[ModelRunnerOutput]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> (Future[ModelRunnerOutput]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Takes a list of futures and returns a single future which resolves | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| to the respective list of outputs.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result_future: Future[ModelRunnerOutput] = Future() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs: list[Optional[ModelRunnerOutput]] = [None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] * len(output_futures) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def make_callback(idx): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def callback(fut): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if result_future.done(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs[idx] = fut.result() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except CancelledError: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result_future.cancel() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result_future.set_exception(e) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # this check assumes io_thread_pool uses a single thread | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if all(outputs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result_future.set_result( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._aggregate_workers_output( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cast(list[ModelRunnerOutput], outputs))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return callback | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i, output_future in enumerate(output_futures): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_future.add_done_callback(make_callback(i)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return result_future | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _ensure_worker_termination(worker_procs: list[BaseProcess]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Ensure that all worker processes are terminated. Assumes workers have | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,6 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import copy | ||
| import gc | ||
| import time | ||
| import weakref | ||
|
|
@@ -1234,8 +1233,6 @@ def _pool( | |
| hidden_states: torch.Tensor, | ||
| num_scheduled_tokens: int, | ||
| num_scheduled_tokens_np: np.ndarray, | ||
| finished_sending: Optional[set[str]], | ||
| finished_recving: Optional[set[str]], | ||
| ) -> ModelRunnerOutput: | ||
| assert self.input_batch.num_reqs ==\ | ||
| len(self.input_batch.pooling_params), \ | ||
|
|
@@ -1270,8 +1267,6 @@ def _pool( | |
| logprobs=None, | ||
| prompt_logprobs_dict={}, | ||
| pooler_output=pooler_output, | ||
| finished_sending=finished_sending, | ||
| finished_recving=finished_recving, | ||
| ) | ||
|
|
||
| @torch.inference_mode() | ||
|
|
@@ -1282,11 +1277,12 @@ def execute_model( | |
| ) -> Union[ModelRunnerOutput, IntermediateTensors]: | ||
| self._update_states(scheduler_output) | ||
| if not scheduler_output.total_num_scheduled_tokens: | ||
| if not has_kv_transfer_group(): | ||
| # Return empty ModelRunnerOutput if there's no work to do. | ||
| return EMPTY_MODEL_RUNNER_OUTPUT | ||
| if has_kv_transfer_group(): | ||
| with set_forward_context(None, self.vllm_config): | ||
| self.maybe_setup_kv_connector(scheduler_output) | ||
|
|
||
|
||
| return self.kv_connector_no_forward(scheduler_output) | ||
| # Return empty ModelRunnerOutput if there's no work to do. | ||
| return EMPTY_MODEL_RUNNER_OUTPUT | ||
|
|
||
| # Prepare the decoder inputs. | ||
| (attn_metadata, attention_cuda_graphs, logits_indices, | ||
|
|
@@ -1379,8 +1375,6 @@ def execute_model( | |
| ) | ||
|
|
||
| self.maybe_wait_for_kv_save() | ||
| finished_sending, finished_recving = ( | ||
| self.get_finished_kv_transfers(scheduler_output)) | ||
|
|
||
| if self.use_aux_hidden_state_outputs: | ||
| hidden_states, aux_hidden_states = model_output | ||
|
|
@@ -1406,8 +1400,7 @@ def execute_model( | |
| else: | ||
| if self.input_batch.pooling_params: | ||
| return self._pool(hidden_states, num_scheduled_tokens, | ||
| num_scheduled_tokens_np, finished_sending, | ||
| finished_recving) | ||
| num_scheduled_tokens_np) | ||
|
|
||
| sample_hidden_states = hidden_states[logits_indices] | ||
| logits = self.model.compute_logits(sample_hidden_states, None) | ||
|
|
@@ -1560,8 +1553,6 @@ def execute_model( | |
| logprobs=logprobs_lists, | ||
| prompt_logprobs_dict=prompt_logprobs_dict, | ||
| pooler_output=[], | ||
| finished_sending=finished_sending, | ||
| finished_recving=finished_recving, | ||
| num_nans_in_logits=num_nans_in_logits, | ||
| ) | ||
|
|
||
|
|
@@ -1686,22 +1677,6 @@ def propose_draft_token_ids( | |
| spec_token_ids = draft_token_ids.tolist() | ||
| return spec_token_ids | ||
|
|
||
| def kv_connector_no_forward( | ||
| self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: | ||
| # KV send/recv even if no work to do. | ||
| with set_forward_context(None, self.vllm_config): | ||
| self.maybe_setup_kv_connector(scheduler_output) | ||
| finished_sending, finished_recving = ( | ||
| self.get_finished_kv_transfers(scheduler_output)) | ||
|
|
||
| if not finished_sending and not finished_recving: | ||
| return EMPTY_MODEL_RUNNER_OUTPUT | ||
|
|
||
| output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) | ||
| output.finished_sending = finished_sending | ||
| output.finished_recving = finished_recving | ||
| return output | ||
|
|
||
| @staticmethod | ||
| def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): | ||
| # Update KVConnector with the KVConnector metadata forward(). | ||
|
|
@@ -1723,15 +1698,6 @@ def maybe_wait_for_kv_save() -> None: | |
| if has_kv_transfer_group(): | ||
| get_kv_transfer_group().wait_for_save() | ||
|
|
||
| @staticmethod | ||
| def get_finished_kv_transfers( | ||
| scheduler_output: "SchedulerOutput", | ||
| ) -> tuple[Optional[set[str]], Optional[set[str]]]: | ||
| if has_kv_transfer_group(): | ||
| return get_kv_transfer_group().get_finished( | ||
| scheduler_output.finished_req_ids) | ||
| return None, None | ||
|
|
||
| def propose_ngram_draft_token_ids( | ||
| self, | ||
| sampled_token_ids: list[list[int]], | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.