Skip to content

Commit e673bdd

Browse files
committed
Move new GPUModelRunner methods out of execute_model method
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 5c3fc88 commit e673bdd

File tree

1 file changed

+52
-42
lines changed

1 file changed

+52
-42
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,48 +1043,13 @@ def execute_model(
10431043
intermediate_tensors: Optional[IntermediateTensors] = None,
10441044
) -> Union[ModelRunnerOutput, IntermediateTensors]:
10451045

1046-
def maybe_setup_kv_connector():
1047-
# Update KVConnector with the KVConnector metadata before forward().
1048-
if has_kv_transfer_group():
1049-
kv_connector = get_kv_transfer_group()
1050-
assert isinstance(kv_connector, KVConnectorBase_V1)
1051-
assert scheduler_output.kv_connector_metadata is not None
1052-
kv_connector.bind_connector_metadata(
1053-
scheduler_output.kv_connector_metadata)
1054-
1055-
# Background KV cache transfers happen here.
1056-
# These transfers are designed to be async and the requests
1057-
# involved may be disjoint from the running requests.
1058-
# Do this here to save a collective_rpc.
1059-
kv_connector.start_load_kv(get_forward_context())
1060-
1061-
def maybe_wait_for_save():
1062-
if has_kv_transfer_group():
1063-
kv_connector = get_kv_transfer_group()
1064-
kv_connector.wait_for_save()
1065-
1066-
def maybe_get_finished() -> tuple[set[str], set[str]]:
1067-
if has_kv_transfer_group():
1068-
kv_connector = get_kv_transfer_group()
1069-
return kv_connector.get_finished()
1070-
return set(), set()
1071-
10721046
self._update_states(scheduler_output)
10731047
if not scheduler_output.total_num_scheduled_tokens:
1074-
# KV send/recv even if no work to do.
1075-
with set_forward_context(None, self.vllm_config):
1076-
maybe_setup_kv_connector()
1077-
maybe_wait_for_save()
1078-
finished_sending, finished_recving = maybe_get_finished()
1079-
1080-
# Return empty ModelRunnerOutput if there's no work to do.
1081-
output = EMPTY_MODEL_RUNNER_OUTPUT
1048+
if not has_kv_transfer_group():
1049+
# Return empty ModelRunnerOutput if there's no work to do.
1050+
return EMPTY_MODEL_RUNNER_OUTPUT
10821051

1083-
if len(finished_sending) > 0 or len(finished_recving) > 0:
1084-
output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
1085-
output.finished_sending = finished_sending
1086-
output.finished_recving = finished_recving
1087-
return output
1052+
return self.kv_connector_no_forward(scheduler_output)
10881053

10891054
# Prepare the decoder inputs.
10901055
attn_metadata, logits_indices, spec_decode_metadata = (
@@ -1161,7 +1126,7 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
11611126
with set_forward_context(attn_metadata,
11621127
self.vllm_config,
11631128
num_tokens=num_input_tokens):
1164-
maybe_setup_kv_connector()
1129+
self.maybe_setup_kv_connector(scheduler_output)
11651130

11661131
model_output = self.model(
11671132
input_ids=input_ids,
@@ -1170,8 +1135,9 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
11701135
inputs_embeds=inputs_embeds,
11711136
)
11721137

1173-
maybe_wait_for_save()
1174-
finished_sending, finished_recving = maybe_get_finished()
1138+
self.maybe_wait_for_kv_save()
1139+
finished_sending, finished_recving = (
1140+
self.get_finished_kv_transfers())
11751141

11761142
if self.use_aux_hidden_state_outputs:
11771143
hidden_states, aux_hidden_states = model_output
@@ -1361,6 +1327,50 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
13611327
finished_recving=finished_recving,
13621328
)
13631329

1330+
def kv_connector_no_forward(
        self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
    """Drive KV send/recv progress when there is no forward pass to run.

    Returns EMPTY_MODEL_RUNNER_OUTPUT, augmented with the finished
    send/recv request-id sets when any transfers completed.
    """
    # KV transfers still need to make progress even with no scheduled work.
    with set_forward_context(None, self.vllm_config):
        self.maybe_setup_kv_connector(scheduler_output)
        sending, recving = self.get_finished_kv_transfers()

    if not (sending or recving):
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Shallow copy suffices here: only top-level fields are reassigned.
    result = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    result.finished_sending = sending
    result.finished_recving = recving
    return result
1345+
1346+
@staticmethod
def maybe_setup_kv_connector(scheduler_output: SchedulerOutput):
    """Bind KVConnector metadata and start background KV loads.

    No-op when no KV transfer group is configured.
    """
    if not has_kv_transfer_group():
        return

    # Update the KVConnector with the scheduler-provided metadata
    # before forward().
    connector = get_kv_transfer_group()
    assert isinstance(connector, KVConnectorBase_V1)
    assert scheduler_output.kv_connector_metadata is not None
    connector.bind_connector_metadata(
        scheduler_output.kv_connector_metadata)

    # Background KV cache transfers happen here. These transfers are
    # designed to be async and the requests involved may be disjoint
    # from the running requests. Doing it here saves a collective_rpc.
    connector.start_load_kv(get_forward_context())
1361+
1362+
@staticmethod
def maybe_wait_for_kv_save() -> None:
    """Block until in-flight KV-cache saves finish (no-op otherwise)."""
    if not has_kv_transfer_group():
        return
    get_kv_transfer_group().wait_for_save()
1366+
1367+
@staticmethod
def get_finished_kv_transfers(
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Return (finished_sending, finished_recving) request-id sets.

    Returns (None, None) when no KV transfer group is configured.
    """
    if not has_kv_transfer_group():
        return None, None
    return get_kv_transfer_group().get_finished()
1373+
13641374
def generate_draft_token_ids(
13651375
self,
13661376
sampled_token_ids: list[list[int]],

0 commit comments

Comments
 (0)