@@ -1076,48 +1076,13 @@ def execute_model(
10761076 intermediate_tensors : Optional [IntermediateTensors ] = None ,
10771077 ) -> Union [ModelRunnerOutput , IntermediateTensors ]:
10781078
1079- def maybe_setup_kv_connector ():
1080- # Update KVConnector with the KVConnector metadata forward().
1081- if has_kv_transfer_group ():
1082- kv_connector = get_kv_transfer_group ()
1083- assert isinstance (kv_connector , KVConnectorBase_V1 )
1084- assert scheduler_output .kv_connector_metadata is not None
1085- kv_connector .bind_connector_metadata (
1086- scheduler_output .kv_connector_metadata )
1087-
1088- # Background KV cache transfers happen here.
1089- # These transfers are designed to be async and the requests
1090- # involved may be disjoint from the running requests.
1091- # Do this here to save a collective_rpc.
1092- kv_connector .start_load_kv (get_forward_context ())
1093-
1094- def maybe_wait_for_save ():
1095- if has_kv_transfer_group ():
1096- kv_connector = get_kv_transfer_group ()
1097- kv_connector .wait_for_save ()
1098-
1099- def maybe_get_finished () -> tuple [set [str ], set [str ]]:
1100- if has_kv_transfer_group ():
1101- kv_connector = get_kv_transfer_group ()
1102- return kv_connector .get_finished ()
1103- return set (), set ()
1104-
11051079 self ._update_states (scheduler_output )
11061080 if not scheduler_output .total_num_scheduled_tokens :
1107- # KV send/recv even if no work to do.
1108- with set_forward_context (None , self .vllm_config ):
1109- maybe_setup_kv_connector ()
1110- maybe_wait_for_save ()
1111- finished_sending , finished_recving = maybe_get_finished ()
1112-
1113- # Return empty ModelRunnerOutput if there's no work to do.
1114- output = EMPTY_MODEL_RUNNER_OUTPUT
1081+ if not has_kv_transfer_group ():
1082+ # Return empty ModelRunnerOutput if there's no work to do.
1083+ return EMPTY_MODEL_RUNNER_OUTPUT
11151084
1116- if len (finished_sending ) > 0 or len (finished_recving ) > 0 :
1117- output = copy .deepcopy (EMPTY_MODEL_RUNNER_OUTPUT )
1118- output .finished_sending = finished_sending
1119- output .finished_recving = finished_recving
1120- return output
1085+ return self .kv_connector_no_forward (scheduler_output )
11211086
11221087 # Prepare the decoder inputs.
11231088 attn_metadata , logits_indices , spec_decode_metadata = (
@@ -1194,7 +1159,7 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
11941159 with set_forward_context (attn_metadata ,
11951160 self .vllm_config ,
11961161 num_tokens = num_input_tokens ):
1197- maybe_setup_kv_connector ()
1162+ self . maybe_setup_kv_connector (scheduler_output )
11981163
11991164 model_output = self .model (
12001165 input_ids = input_ids ,
@@ -1203,8 +1168,9 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
12031168 inputs_embeds = inputs_embeds ,
12041169 )
12051170
1206- maybe_wait_for_save ()
1207- finished_sending , finished_recving = maybe_get_finished ()
1171+ self .maybe_wait_for_kv_save ()
1172+ finished_sending , finished_recving = (
1173+ self .get_finished_kv_transfers ())
12081174
12091175 if self .use_aux_hidden_state_outputs :
12101176 hidden_states , aux_hidden_states = model_output
@@ -1394,6 +1360,50 @@ def maybe_get_finished() -> tuple[set[str], set[str]]:
13941360 finished_recving = finished_recving ,
13951361 )
13961362
def kv_connector_no_forward(
        self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
    """Drive the KV connector for a step that runs no model forward pass.

    The connector is set up and polled inside a forward context created
    with ``None`` as its first argument, so KV send/recv still progresses
    even though there is no work to do.

    Args:
        scheduler_output: Scheduler output for this step; handed to
            ``maybe_setup_kv_connector``.

    Returns:
        ``EMPTY_MODEL_RUNNER_OUTPUT`` itself when neither finished set is
        truthy; otherwise a shallow copy of it whose ``finished_sending``
        and ``finished_recving`` attributes carry the polled values.
    """
    with set_forward_context(None, self.vllm_config):
        self.maybe_setup_kv_connector(scheduler_output)
        sent, received = self.get_finished_kv_transfers()

    if sent or received:
        # Work on a copy so the shared empty-output object is never
        # mutated; only these two attributes are rebound on the copy.
        result = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        result.finished_sending = sent
        result.finished_recving = received
        return result

    # Nothing finished (both values are None or empty).
    return EMPTY_MODEL_RUNNER_OUTPUT
1378+
@staticmethod
def maybe_setup_kv_connector(scheduler_output: SchedulerOutput):
    """Bind this step's connector metadata and start loading KV.

    Does nothing when no KV transfer group exists.

    Args:
        scheduler_output: Scheduler output whose ``kv_connector_metadata``
            is bound to the connector; asserted to be non-``None``.
    """
    if not has_kv_transfer_group():
        return

    connector = get_kv_transfer_group()
    assert isinstance(connector, KVConnectorBase_V1)

    # Hand the scheduler-provided metadata to the connector before the
    # forward pass.
    metadata = scheduler_output.kv_connector_metadata
    assert metadata is not None
    connector.bind_connector_metadata(metadata)

    # Background KV cache transfers are started here. They are designed
    # to be async, and the requests involved may be disjoint from the
    # running requests. Doing it here saves a collective_rpc.
    connector.start_load_kv(get_forward_context())
1394+
@staticmethod
def maybe_wait_for_kv_save() -> None:
    """Call ``wait_for_save()`` on the KV connector, if one exists.

    Does nothing when no KV transfer group exists.
    """
    if not has_kv_transfer_group():
        return

    connector = get_kv_transfer_group()
    connector.wait_for_save()
1399+
@staticmethod
def get_finished_kv_transfers(
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Poll the KV connector for finished transfers.

    Returns:
        Whatever the connector's ``get_finished()`` returns, unpacked by
        callers as ``(finished_sending, finished_recving)``; or
        ``(None, None)`` when no KV transfer group exists.
    """
    if not has_kv_transfer_group():
        return None, None

    connector = get_kv_transfer_group()
    return connector.get_finished()
1406+
13971407 def generate_draft_token_ids (
13981408 self ,
13991409 sampled_token_ids : list [list [int ]],
0 commit comments