|
2 | 2 | import os |
3 | 3 | import time |
4 | 4 | from functools import partial |
5 | | -from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, |
6 | | - Union, AsyncIterator, Callable) |
| 5 | +from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, |
| 6 | + Union, AsyncIterator) |
7 | 7 |
|
8 | 8 | from transformers import PreTrainedTokenizer |
9 | 9 |
|
10 | 10 | from vllm.lora.request import LoRARequest |
11 | 11 | from vllm.config import ModelConfig |
12 | 12 | from vllm.engine.arg_utils import AsyncEngineArgs |
13 | 13 | from vllm.engine.llm_engine import LLMEngine |
14 | | -from vllm.engine.ray_utils import initialize_cluster, ray |
| 14 | +from vllm.engine.ray_utils import initialize_ray_cluster, ray |
15 | 15 | from vllm.logger import init_logger |
16 | 16 | from vllm.outputs import RequestOutput |
17 | 17 | from vllm.sampling_params import SamplingParams |
@@ -208,17 +208,10 @@ async def step_async(self) -> List[RequestOutput]: |
208 | 208 |
|
209 | 209 | if not scheduler_outputs.is_empty(): |
210 | 210 | # Execute the model. |
211 | | - all_outputs = await self._run_workers_async( |
212 | | - "execute_model", |
213 | | - driver_kwargs={ |
214 | | - "seq_group_metadata_list": seq_group_metadata_list, |
215 | | - "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, |
216 | | - "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, |
217 | | - "blocks_to_copy": scheduler_outputs.blocks_to_copy, |
218 | | - }) |
219 | | - |
220 | | - # Only the driver worker returns the sampling results. |
221 | | - output = all_outputs[0] |
| 211 | + output = await self.model_executor.execute_model_async( |
| 212 | + seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, |
| 213 | + scheduler_outputs.blocks_to_swap_out, |
| 214 | + scheduler_outputs.blocks_to_copy) |
222 | 215 | else: |
223 | 216 | output = [] |
224 | 217 |
|
@@ -268,37 +261,8 @@ async def add_request_async( |
268 | 261 | lora_request=lora_request, |
269 | 262 | ) |
270 | 263 |
|
271 | | - async def _run_workers_async( |
272 | | - self, |
273 | | - method: str, |
274 | | - *args, |
275 | | - driver_args: Optional[List[Any]] = None, |
276 | | - driver_kwargs: Optional[Dict[str, Any]] = None, |
277 | | - **kwargs, |
278 | | - ) -> Any: |
279 | | - """Runs the given method on all workers.""" |
280 | | - coros = [] |
281 | | - |
282 | | - if driver_args is None: |
283 | | - driver_args = args |
284 | | - if driver_kwargs is None: |
285 | | - driver_kwargs = kwargs |
286 | | - |
287 | | - # Run the driver worker asynchronously. |
288 | | - driver_executor = getattr(self.driver_worker, method) |
289 | | - coros.append(asyncio.get_event_loop().run_in_executor( |
290 | | - None, partial(driver_executor, *driver_args, **driver_kwargs))) |
291 | | - |
292 | | - # Run the ray workers asynchronously. |
293 | | - for worker in self.workers: |
294 | | - coros.append(worker.execute_method.remote(method, *args, **kwargs)) |
295 | | - |
296 | | - all_outputs = await asyncio.gather(*coros) |
297 | | - return all_outputs |
298 | | - |
299 | | - async def check_health_async(self): |
300 | | - """Raises an error if engine is unhealthy.""" |
301 | | - self._check_if_any_actor_is_dead() |
async def check_health_async(self) -> None:
    """Run the model executor's health check.

    Calls ``self.model_executor.check_health()`` directly (it is not
    awaited), so any exception it raises propagates to the caller of this
    coroutine.

    NOTE(review): presumably ``check_health`` raises when the executor is
    unhealthy -- confirm against the executor implementation.
    """
    executor = self.model_executor
    executor.check_health()
302 | 266 |
|
303 | 267 |
|
304 | 268 | class AsyncLLMEngine: |
@@ -353,6 +317,34 @@ def __init__(self, |
353 | 317 | self._request_tracker: Optional[RequestTracker] = None |
354 | 318 | self._errored_with: Optional[BaseException] = None |
355 | 319 |
|
@classmethod
def from_engine_args(cls,
                     engine_args: AsyncEngineArgs,
                     start_engine_loop: bool = True) -> "AsyncLLMEngine":
    """Build an async LLM engine from parsed engine arguments.

    The engine configs come from ``engine_args.create_engine_configs()``;
    the element at index 2 is used as the parallel config. The executor
    class is chosen from the Ray settings:

    * If ``parallel_config.worker_use_ray`` or ``engine_args.engine_use_ray``
      is truthy, ``initialize_ray_cluster(parallel_config)`` is called and
      ``RayGPUExecutorAsync`` is used.
    * Otherwise ``parallel_config.world_size`` must be 1 and
      ``GPUExecutorAsync`` is used.

    Args:
        engine_args: Source of the engine configs and of the logging flags
            (``disable_log_requests``, ``disable_log_stats``,
            ``max_log_len``) and ``engine_use_ray``.
        start_engine_loop: Forwarded unchanged to the constructor.

    Returns:
        The newly constructed engine instance.

    Raises:
        AssertionError: If Ray is not in use and
            ``parallel_config.world_size`` is not 1.
    """
    configs = engine_args.create_engine_configs()
    parallel_config = configs[2]

    ray_requested = (parallel_config.worker_use_ray
                     or engine_args.engine_use_ray)
    # Executor modules are imported inside the chosen branch only.
    # NOTE(review): presumably this keeps the unused backend from being
    # imported -- confirm.
    if not ray_requested:
        assert parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")
        from vllm.executor.gpu_executor import (
            GPUExecutorAsync as executor_class)
    else:
        # The cluster is initialized before the Ray executor is imported.
        initialize_ray_cluster(parallel_config)
        from vllm.executor.ray_gpu_executor import (
            RayGPUExecutorAsync as executor_class)

    # Positional order: ray flags, then every engine config, then the
    # executor class; logging options go by keyword.
    return cls(parallel_config.worker_use_ray,
               engine_args.engine_use_ray,
               *configs,
               executor_class,
               log_requests=not engine_args.disable_log_requests,
               log_stats=not engine_args.disable_log_stats,
               max_log_len=engine_args.max_log_len,
               start_engine_loop=start_engine_loop)
| 347 | + |
356 | 348 | @property |
357 | 349 | def is_running(self) -> bool: |
358 | 350 | return (self.background_loop is not None |
@@ -670,35 +662,13 @@ async def get_model_config(self) -> ModelConfig: |
670 | 662 | else: |
671 | 663 | return self.engine.get_model_config() |
672 | 664 |
|
673 | | - @classmethod |
674 | | - def from_engine_args(cls, |
675 | | - engine_args: AsyncEngineArgs, |
676 | | - start_engine_loop: bool = True) -> "AsyncLLMEngine": |
677 | | - """Creates an async LLM engine from the engine arguments.""" |
678 | | - # Create the engine configs. |
679 | | - engine_configs = engine_args.create_engine_configs() |
680 | | - parallel_config = engine_configs[2] |
681 | | - # Initialize the cluster. |
682 | | - placement_group = initialize_cluster(parallel_config, |
683 | | - engine_args.engine_use_ray) |
684 | | - # Create the async LLM engine. |
685 | | - engine = cls(parallel_config.worker_use_ray, |
686 | | - engine_args.engine_use_ray, |
687 | | - *engine_configs, |
688 | | - placement_group, |
689 | | - log_requests=not engine_args.disable_log_requests, |
690 | | - log_stats=not engine_args.disable_log_stats, |
691 | | - max_log_len=engine_args.max_log_len, |
692 | | - start_engine_loop=start_engine_loop) |
693 | | - return engine |
694 | | - |
async def do_log_stats(self) -> None:
    """Ask the underlying engine to log its stats.

    When ``self.engine_use_ray`` is falsy, ``self.engine.do_log_stats()`` is
    called directly. Otherwise ``self.engine.do_log_stats.remote()`` is
    invoked and its result awaited.
    """
    if not self.engine_use_ray:
        self.engine.do_log_stats()
        return
    await self.engine.do_log_stats.remote()
700 | 670 |
|
701 | | - async def check_health(self): |
| 671 | + async def check_health(self) -> None: |
702 | 672 | """Raises an error if engine is unhealthy.""" |
703 | 673 | t = time.perf_counter() |
704 | 674 | logger.debug("Starting health check...") |
|
0 commit comments