5 files changed: +13 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -211,9 +211,11 @@ async def step_async(self) -> List[RequestOutput]:
211211 if not scheduler_outputs .is_empty ():
212212 # Execute the model.
213213 output = await self .model_executor .execute_model_async (
214- seq_group_metadata_list , scheduler_outputs .blocks_to_swap_in ,
214+ seq_group_metadata_list ,
215+ scheduler_outputs .blocks_to_swap_in ,
215216 scheduler_outputs .blocks_to_swap_out ,
216- scheduler_outputs .blocks_to_copy )
217+ scheduler_outputs .blocks_to_copy ,
218+ num_lookahead_slots = scheduler_outputs .num_lookahead_slots )
217219 else :
218220 output = []
219221
Original file line number Diff line number Diff line change @@ -105,6 +105,7 @@ async def execute_model_async(
105105 blocks_to_swap_in : Dict [int , int ],
106106 blocks_to_swap_out : Dict [int , int ],
107107 blocks_to_copy : Dict [int , List [int ]],
108+ num_lookahead_slots : int ,
108109 ) -> SamplerOutput :
109110 """Executes one model step on the given sequences."""
110111 raise NotImplementedError
Original file line number Diff line number Diff line change @@ -162,10 +162,12 @@ async def execute_model_async(
162162 blocks_to_swap_in : Dict [int , int ],
163163 blocks_to_swap_out : Dict [int , int ],
164164 blocks_to_copy : Dict [int , List [int ]],
165+ num_lookahead_slots : int ,
165166 ) -> SamplerOutput :
166167 output = await make_async (self .driver_worker .execute_model )(
167168 seq_group_metadata_list = seq_group_metadata_list ,
168169 blocks_to_swap_in = blocks_to_swap_in ,
169170 blocks_to_swap_out = blocks_to_swap_out ,
170- blocks_to_copy = blocks_to_copy )
171+ blocks_to_copy = blocks_to_copy ,
172+ num_lookahead_slots = num_lookahead_slots )
171173 return output
Original file line number Diff line number Diff line change @@ -84,9 +84,11 @@ async def execute_model_async(
8484 blocks_to_swap_in : Dict [int , int ],
8585 blocks_to_swap_out : Dict [int , int ],
8686 blocks_to_copy : Dict [int , List [int ]],
87+ num_lookahead_slots : int ,
8788 ) -> SamplerOutput :
8889 output = await make_async (self .driver_worker .execute_model )(
89- seq_group_metadata_list = seq_group_metadata_list , )
90+ seq_group_metadata_list = seq_group_metadata_list ,
91+ num_lookahead_slots = num_lookahead_slots )
9092 return output
9193
9294 async def check_health_async (self ) -> None :
Original file line number Diff line number Diff line change @@ -420,6 +420,7 @@ async def execute_model_async(
420420 blocks_to_swap_in : Dict [int , int ],
421421 blocks_to_swap_out : Dict [int , int ],
422422 blocks_to_copy : Dict [int , List [int ]],
423+ num_lookahead_slots : int ,
423424 ) -> SamplerOutput :
424425 all_outputs = await self ._run_workers_async (
425426 "execute_model" ,
@@ -428,6 +429,7 @@ async def execute_model_async(
428429 "blocks_to_swap_in" : blocks_to_swap_in ,
429430 "blocks_to_swap_out" : blocks_to_swap_out ,
430431 "blocks_to_copy" : blocks_to_copy ,
432+ "num_lookahead_slots" : num_lookahead_slots ,
431433 })
432434
433435 # Only the driver worker returns the sampling results.
You can’t perform that action at this time.
0 commit comments