Skip to content

Commit c2f2e77

Browse files
Merge pull request vllm-project#10 from njhill/streamline-runner
Move new GPUModelRunner methods out of `execute_model` method
2 parents 3050565 + e673bdd commit c2f2e77

File tree

1 file changed

+52
-42
lines changed

1 file changed

+52
-42
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,48 +1076,13 @@ def execute_model(
10761076
intermediate_tensors: Optional[IntermediateTensors] = None,
10771077
) -> Union[ModelRunnerOutput, IntermediateTensors]:
10781078

1079-
def maybe_setup_kv_connector():
1080-
# Update the KVConnector with the KVConnector metadata before forward().
1081-
if has_kv_transfer_group():
1082-
kv_connector = get_kv_transfer_group()
1083-
assert isinstance(kv_connector, KVConnectorBase_V1)
1084-
assert scheduler_output.kv_connector_metadata is not None
1085-
kv_connector.bind_connector_metadata(
1086-
scheduler_output.kv_connector_metadata)
1087-
1088-
# Background KV cache transfers happen here.
1089-
# These transfers are designed to be async and the requests
1090-
# involved may be disjoint from the running requests.
1091-
# Do this here to save a collective_rpc.
1092-
kv_connector.start_load_kv(get_forward_context())
1093-
1094-
def maybe_wait_for_save():
1095-
if has_kv_transfer_group():
1096-
kv_connector = get_kv_transfer_group()
1097-
kv_connector.wait_for_save()
1098-
1099-
def maybe_get_finished() -> tuple[set[str], set[str]]:
1100-
if has_kv_transfer_group():
1101-
kv_connector = get_kv_transfer_group()
1102-
return kv_connector.get_finished()
1103-
return set(), set()
1104-
11051079
self._update_states(scheduler_output)
11061080
if not scheduler_output.total_num_scheduled_tokens:
1107-
# KV send/recv even if no work to do.
1108-
with set_forward_context(None, self.vllm_config):
1109-
maybe_setup_kv_connector()
1110-
maybe_wait_for_save()
1111-
finished_sending, finished_recving = maybe_get_finished()
1112-
1113-
# Return empty ModelRunnerOutput if there's no work to do.
1114-
output = EMPTY_MODEL_RUNNER_OUTPUT
1081+
if not has_kv_transfer_group():
1082+
# Return empty ModelRunnerOutput if there's no work to do.
1083+
return EMPTY_MODEL_RUNNER_OUTPUT
11151084

1116-
if len(finished_sending) > 0 or len(finished_recving) > 0:
1117-
output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
1118-
output.finished_sending = finished_sending
1119-
output.finished_recving = finished_recving
1120-
return output
1085+
return self.kv_connector_no_forward(scheduler_output)
11211086

11221087
# Prepare the decoder inputs.
11231088
attn_metadata, logits_indices, spec_decode_metadata = (
@@ -1194,7 +1159,7 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
11941159
with set_forward_context(attn_metadata,
11951160
self.vllm_config,
11961161
num_tokens=num_input_tokens):
1197-
maybe_setup_kv_connector()
1162+
self.maybe_setup_kv_connector(scheduler_output)
11981163

11991164
model_output = self.model(
12001165
input_ids=input_ids,
@@ -1203,8 +1168,9 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
12031168
inputs_embeds=inputs_embeds,
12041169
)
12051170

1206-
maybe_wait_for_save()
1207-
finished_sending, finished_recving = maybe_get_finished()
1171+
self.maybe_wait_for_kv_save()
1172+
finished_sending, finished_recving = (
1173+
self.get_finished_kv_transfers())
12081174

12091175
if self.use_aux_hidden_state_outputs:
12101176
hidden_states, aux_hidden_states = model_output
@@ -1394,6 +1360,50 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
13941360
finished_recving=finished_recving,
13951361
)
13961362

1363+
def kv_connector_no_forward(
        self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
    """Drive KV-connector send/recv for a step with no forward pass.

    Binds the connector metadata and kicks off background KV-cache
    transfers inside an (empty) forward context, then reports any
    requests whose transfers completed.
    """
    # KV send/recv must still happen even if there is no work to do.
    with set_forward_context(None, self.vllm_config):
        self.maybe_setup_kv_connector(scheduler_output)
        done_sending, done_recving = (
            self.get_finished_kv_transfers())

    if not done_sending and not done_recving:
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Shallow copy suffices: only two attributes are rebound below,
    # nothing on the shared empty output is mutated in place.
    result = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    result.finished_sending = done_sending
    result.finished_recving = done_recving
    return result
1378+
1379+
@staticmethod
def maybe_setup_kv_connector(scheduler_output: SchedulerOutput):
    """Bind this step's KV-connector metadata and start background
    KV-cache loads, if a KV transfer group is configured."""
    if not has_kv_transfer_group():
        return

    # Update the KVConnector with the KVConnector metadata for forward().
    kv_connector = get_kv_transfer_group()
    assert isinstance(kv_connector, KVConnectorBase_V1)
    assert scheduler_output.kv_connector_metadata is not None
    kv_connector.bind_connector_metadata(
        scheduler_output.kv_connector_metadata)

    # Background KV cache transfers happen here. These transfers are
    # designed to be async and the requests involved may be disjoint
    # from the running requests. Doing it here saves a collective_rpc.
    kv_connector.start_load_kv(get_forward_context())
1394+
1395+
@staticmethod
def maybe_wait_for_kv_save() -> None:
    """Block until in-flight KV-cache saves complete; no-op when no
    KV transfer group is configured."""
    if not has_kv_transfer_group():
        return
    get_kv_transfer_group().wait_for_save()
1399+
1400+
@staticmethod
def get_finished_kv_transfers(
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Return the (sending, receiving) request-id sets whose async KV
    transfers have finished, or (None, None) when no KV transfer group
    is configured."""
    if not has_kv_transfer_group():
        return None, None
    return get_kv_transfer_group().get_finished()
1406+
13971407
def generate_draft_token_ids(
13981408
self,
13991409
sampled_token_ids: list[list[int]],

0 commit comments

Comments
 (0)