33import pickle
44from collections import defaultdict
55from itertools import islice , repeat
6- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Set , Tuple
6+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple
77
8- from vllm .executor .executor_base import ExecutorAsyncBase , ExecutorBase
8+ from vllm .executor .distributed_gpu_executor import ( # yapf: disable
9+ DistributedGPUExecutor , DistributedGPUExecutorAsync )
910from vllm .executor .ray_utils import RayWorkerWrapper , ray
1011from vllm .logger import init_logger
11- from vllm .lora .request import LoRARequest
1212from vllm .sequence import SamplerOutput , SequenceGroupMetadata
1313from vllm .utils import (get_distributed_init_method , get_ip , get_open_port ,
1414 get_vllm_instance_id , make_async )
2727USE_RAY_COMPILED_DAG = bool (os .getenv ("VLLM_USE_RAY_COMPILED_DAG" , 0 ))
2828
2929
30- class RayGPUExecutor (ExecutorBase ):
30+ class RayGPUExecutor (DistributedGPUExecutor ):
3131
3232 def _init_executor (self ) -> None :
3333 assert (not self .speculative_config
@@ -179,50 +179,9 @@ def collect_arg_helper_func(**kwargs):
179179 self ._run_workers ("init_worker" , all_kwargs = init_worker_all_kwargs )
180180
181181 self ._run_workers ("init_device" )
182- self ._run_workers (
183- "load_model" ,
184- max_concurrent_workers = self .parallel_config .
185- max_parallel_loading_workers ,
186- )
187-
188- def determine_num_available_blocks (self ) -> Tuple [int , int ]:
189- """Determine the number of available KV blocks.
190-
191- This invokes `determine_num_available_blocks` on each worker and takes
192- the min of the results, guaranteeing that the selected cache sizes are
193- compatible with all workers.
194-
195- Returns:
196- - Tuple[num_gpu_blocks, num_cpu_blocks]
197- """
198- # Get the maximum number of blocks that can be allocated on GPU and CPU.
199- num_blocks = self ._run_workers ("determine_num_available_blocks" , )
200-
201- # Since we use a shared centralized controller, we take the minimum
202- # number of blocks across all workers to make sure all the memory
203- # operators can be applied to all workers.
204- num_gpu_blocks = min (b [0 ] for b in num_blocks )
205- num_cpu_blocks = min (b [1 ] for b in num_blocks )
206-
207- return num_gpu_blocks , num_cpu_blocks
208-
209- def initialize_cache (self , num_gpu_blocks : int ,
210- num_cpu_blocks : int ) -> None :
211- """Initialize the KV cache in all workers.
212- """
213-
214- # NOTE: We log here to avoid multiple logs when number of workers is
215- # greater than one. We could log in the engine, but not all executors
216- # have GPUs.
217- logger .info ("# GPU blocks: %d, # CPU blocks: %d" , num_gpu_blocks ,
218- num_cpu_blocks )
219-
220- self .cache_config .num_gpu_blocks = num_gpu_blocks
221- self .cache_config .num_cpu_blocks = num_cpu_blocks
222-
223- self ._run_workers ("initialize_cache" ,
224- num_gpu_blocks = num_gpu_blocks ,
225- num_cpu_blocks = num_cpu_blocks )
182+ self ._run_workers ("load_model" ,
183+ max_concurrent_workers = self .parallel_config .
184+ max_parallel_loading_workers )
226185
227186 def execute_model (self ,
228187 seq_group_metadata_list : List [SequenceGroupMetadata ],
@@ -244,23 +203,6 @@ def execute_model(self,
244203 output = all_outputs [0 ]
245204 return output
246205
247- def add_lora (self , lora_request : LoRARequest ) -> bool :
248- assert lora_request .lora_int_id > 0 , "lora_id must be greater than 0."
249- return self ._run_workers (
250- "add_lora" ,
251- lora_request = lora_request ,
252- )
253-
254- def remove_lora (self , lora_id : int ) -> bool :
255- assert lora_id > 0 , "lora_id must be greater than 0."
256- return self ._run_workers (
257- "remove_lora" ,
258- lora_id = lora_id ,
259- )
260-
261- def list_loras (self ) -> Set [int ]:
262- return self ._run_workers ("list_loras" )
263-
264206 def _run_workers (
265207 self ,
266208 method : str ,
@@ -378,7 +320,7 @@ def _check_if_any_actor_is_dead(self):
378320 f"Dead Workers: { dead_actors } . " )
379321
380322
381- class RayGPUExecutorAsync (RayGPUExecutor , ExecutorAsyncBase ):
323+ class RayGPUExecutorAsync (RayGPUExecutor , DistributedGPUExecutorAsync ):
382324
383325 def __init__ (self , * args , ** kwargs ):
384326 super ().__init__ (* args , ** kwargs )
@@ -409,23 +351,3 @@ async def _run_workers_async(
409351
410352 all_outputs = await asyncio .gather (* coros )
411353 return all_outputs
412-
413- async def execute_model_async (
414- self ,
415- seq_group_metadata_list : List [SequenceGroupMetadata ],
416- blocks_to_swap_in : Dict [int , int ],
417- blocks_to_swap_out : Dict [int , int ],
418- blocks_to_copy : Dict [int , List [int ]],
419- ) -> SamplerOutput :
420- all_outputs = await self ._run_workers_async (
421- "execute_model" ,
422- driver_kwargs = {
423- "seq_group_metadata_list" : seq_group_metadata_list ,
424- "blocks_to_swap_in" : blocks_to_swap_in ,
425- "blocks_to_swap_out" : blocks_to_swap_out ,
426- "blocks_to_copy" : blocks_to_copy ,
427- })
428-
429- # Only the driver worker returns the sampling results.
430- output = all_outputs [0 ]
431- return output