@@ -1043,48 +1043,13 @@ def execute_model(
10431043 intermediate_tensors : Optional [IntermediateTensors ] = None ,
10441044 ) -> Union [ModelRunnerOutput , IntermediateTensors ]:
10451045
1046- def maybe_setup_kv_connector ():
1047- # Update KVConnector with the KVConnector metadata forward().
1048- if has_kv_transfer_group ():
1049- kv_connector = get_kv_transfer_group ()
1050- assert isinstance (kv_connector , KVConnectorBase_V1 )
1051- assert scheduler_output .kv_connector_metadata is not None
1052- kv_connector .bind_connector_metadata (
1053- scheduler_output .kv_connector_metadata )
1054-
1055- # Background KV cache transfers happen here.
1056- # These transfers are designed to be async and the requests
1057- # involved may be disjoint from the running requests.
1058- # Do this here to save a collective_rpc.
1059- kv_connector .start_load_kv (get_forward_context ())
1060-
1061- def maybe_wait_for_save ():
1062- if has_kv_transfer_group ():
1063- kv_connector = get_kv_transfer_group ()
1064- kv_connector .wait_for_save ()
1065-
1066- def maybe_get_finished () -> tuple [set [str ], set [str ]]:
1067- if has_kv_transfer_group ():
1068- kv_connector = get_kv_transfer_group ()
1069- return kv_connector .get_finished ()
1070- return set (), set ()
1071-
10721046 self ._update_states (scheduler_output )
10731047 if not scheduler_output .total_num_scheduled_tokens :
1074- # KV send/recv even if no work to do.
1075- with set_forward_context (None , self .vllm_config ):
1076- maybe_setup_kv_connector ()
1077- maybe_wait_for_save ()
1078- finished_sending , finished_recving = maybe_get_finished ()
1079-
1080- # Return empty ModelRunnerOutput if there's no work to do.
1081- output = EMPTY_MODEL_RUNNER_OUTPUT
1048+ if not has_kv_transfer_group ():
1049+ # Return empty ModelRunnerOutput if there's no work to do.
1050+ return EMPTY_MODEL_RUNNER_OUTPUT
10821051
1083- if len (finished_sending ) > 0 or len (finished_recving ) > 0 :
1084- output = copy .deepcopy (EMPTY_MODEL_RUNNER_OUTPUT )
1085- output .finished_sending = finished_sending
1086- output .finished_recving = finished_recving
1087- return output
1052+ return self .kv_connector_no_forward (scheduler_output )
10881053
10891054 # Prepare the decoder inputs.
10901055 attn_metadata , logits_indices , spec_decode_metadata = (
@@ -1161,7 +1126,7 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
11611126 with set_forward_context (attn_metadata ,
11621127 self .vllm_config ,
11631128 num_tokens = num_input_tokens ):
1164- maybe_setup_kv_connector ()
1129+ self . maybe_setup_kv_connector (scheduler_output )
11651130
11661131 model_output = self .model (
11671132 input_ids = input_ids ,
@@ -1170,8 +1135,9 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
11701135 inputs_embeds = inputs_embeds ,
11711136 )
11721137
1173- maybe_wait_for_save ()
1174- finished_sending , finished_recving = maybe_get_finished ()
1138+ self .maybe_wait_for_kv_save ()
1139+ finished_sending , finished_recving = (
1140+ self .get_finished_kv_transfers ())
11751141
11761142 if self .use_aux_hidden_state_outputs :
11771143 hidden_states , aux_hidden_states = model_output
@@ -1361,6 +1327,50 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
13611327 finished_recving = finished_recving ,
13621328 )
13631329
1330+ def kv_connector_no_forward (
1331+ self , scheduler_output : SchedulerOutput ) -> ModelRunnerOutput :
1332+ # KV send/recv even if no work to do.
1333+ with set_forward_context (None , self .vllm_config ):
1334+ self .maybe_setup_kv_connector (scheduler_output )
1335+ finished_sending , finished_recving = (
1336+ self .get_finished_kv_transfers ())
1337+
1338+ if not finished_sending and not finished_recving :
1339+ return EMPTY_MODEL_RUNNER_OUTPUT
1340+
1341+ output = copy .copy (EMPTY_MODEL_RUNNER_OUTPUT )
1342+ output .finished_sending = finished_sending
1343+ output .finished_recving = finished_recving
1344+ return output
1345+
1346+ @staticmethod
1347+ def maybe_setup_kv_connector (scheduler_output : SchedulerOutput ):
1348+ # Update KVConnector with the KVConnector metadata forward().
1349+ if has_kv_transfer_group ():
1350+ kv_connector = get_kv_transfer_group ()
1351+ assert isinstance (kv_connector , KVConnectorBase_V1 )
1352+ assert scheduler_output .kv_connector_metadata is not None
1353+ kv_connector .bind_connector_metadata (
1354+ scheduler_output .kv_connector_metadata )
1355+
1356+ # Background KV cache transfers happen here.
1357+ # These transfers are designed to be async and the requests
1358+ # involved may be disjoint from the running requests.
1359+ # Do this here to save a collective_rpc.
1360+ kv_connector .start_load_kv (get_forward_context ())
1361+
1362+ @staticmethod
1363+ def maybe_wait_for_kv_save () -> None :
1364+ if has_kv_transfer_group ():
1365+ get_kv_transfer_group ().wait_for_save ()
1366+
1367+ @staticmethod
1368+ def get_finished_kv_transfers (
1369+ ) -> tuple [Optional [set [str ]], Optional [set [str ]]]:
1370+ if has_kv_transfer_group ():
1371+ return get_kv_transfer_group ().get_finished ()
1372+ return None , None
1373+
13641374 def generate_draft_token_ids (
13651375 self ,
13661376 sampled_token_ids : list [list [int ]],
0 commit comments