[Core] Eliminate parallel worker per-step task scheduling overhead #4894

Merged · 7 commits · May 22, 2024
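Summary of the change: previously, every engine step scheduled a fresh `execute_model` task on each parallel worker, so tensor-parallel runs paid task-submission overhead on every decode step. After this PR, the driver starts a persistent `start_worker_execution_loop` on the remote workers the first time a step executes, drives subsequent steps itself, and tears the loop down via `stop_remote_worker_execution_loop` once the engine goes idle. Below is a toy sketch of that pattern; all names are illustrative, not vllm's actual classes.

```python
# Toy sketch of the pattern this PR introduces (illustrative names,
# not vllm's classes): workers run one persistent loop that is fed
# per step, instead of receiving one scheduled task per step.
import asyncio
from typing import List, Optional


class Worker:
    def __init__(self) -> None:
        self.inbox: "asyncio.Queue[Optional[object]]" = asyncio.Queue()

    async def execution_loop(self) -> None:
        # Started once; exits on the stop sentinel (None).
        while await self.inbox.get() is not None:
            pass  # execute one model step here


class Driver:
    def __init__(self, workers: List[Worker]) -> None:
        self.workers = workers
        self.worker_tasks: Optional[asyncio.Future] = None

    async def step(self, request: object) -> None:
        if self.worker_tasks is None:
            # First step after idle: launch the persistent loops once.
            self.worker_tasks = asyncio.gather(
                *(w.execution_loop() for w in self.workers))
        for w in self.workers:
            # Per-step hand-off: no new task is scheduled per worker.
            w.inbox.put_nowait(request)

    async def stop_workers(self) -> None:
        # Mirrors stop_remote_worker_execution_loop: release the
        # workers once the engine goes idle.
        if self.worker_tasks is None:
            return
        for w in self.workers:
            w.inbox.put_nowait(None)  # stop sentinel
        tasks, self.worker_tasks = self.worker_tasks, None
        await tasks  # re-raises any worker-side exception


async def main() -> None:
    driver = Driver([Worker(), Worker()])
    await driver.step("prompt batch")
    await driver.stop_workers()


asyncio.run(main())
```

In the real PR the per-step hand-off happens through vllm's tensor-parallel broadcast rather than a queue, and the driver runs the model itself via `_driver_execute_model`; the queue above only stands in for that broadcast.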
6 changes: 5 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -234,6 +234,10 @@ async def step_async(
# Log stats.
self.do_log_stats(scheduler_outputs, output)

if not request_outputs:
# Stop the execute model loop in parallel workers for now
await self.model_executor.stop_remote_worker_execution_loop_async()

return request_outputs

async def encode_request_async(
@@ -687,7 +691,7 @@ async def encode(
multi_modal_data: Multi modal data per request.

Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.

Details:
4 changes: 4 additions & 0 deletions vllm/engine/llm_engine.py
@@ -690,6 +690,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Log stats.
self.do_log_stats(scheduler_outputs, output)

if not request_outputs:
# Stop the execute model loop in parallel workers for now
self.model_executor.stop_remote_worker_execution_loop()

return request_outputs

def do_log_stats(
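Both engine hunks (sync and async) use `not request_outputs` as the idle signal: the first step that produces no outputs releases the workers from their loop. From the caller's side nothing changes; here is a hedged sketch of the standard synchronous engine loop this interacts with (model name and prompt are placeholders).

```python
# Hedged sketch of driving LLMEngine directly (standard vllm usage;
# the model and prompt are placeholders).
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", tensor_parallel_size=2))
engine.add_request("req-0", "Hello, my name is",
                   SamplingParams(max_tokens=16))

while engine.has_unfinished_requests():
    # While work is in flight, the remote workers stay inside their
    # persistent execution loop across these step() calls.
    for out in engine.step():
        if out.finished:
            print(out.outputs[0].text)

# A subsequent step with nothing scheduled returns no outputs, so
# step() itself releases the workers via
# stop_remote_worker_execution_loop().
engine.step()
```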
83 changes: 58 additions & 25 deletions vllm/executor/distributed_gpu_executor.py
@@ -1,18 +1,29 @@
import asyncio
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union

from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput

logger = init_logger(__name__)


class DistributedGPUExecutor(GPUExecutor):
"""Abstract superclass of multi-GPU executor implementations."""

def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self.extra_execute_model_run_workers_kwargs = {}

super().__init__(*args, **kwargs)

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.

@@ -52,13 +63,17 @@ def initialize_cache(self, num_gpu_blocks: int,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)

def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
all_outputs = self._run_workers("execute_model",
driver_args=args,
driver_kwargs=kwargs)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
remote_workers_only_async=True,
**self.extra_execute_model_run_workers_kwargs)

# Only the driver worker returns the sampling results.
return all_outputs[0]
return self._driver_execute_model(execute_model_req)

def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
@@ -88,13 +103,19 @@ def save_sharded_state(
pattern=pattern,
max_size=max_size)

@abstractmethod
def _driver_execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Run execute_model in the driver worker."""
raise NotImplementedError

@abstractmethod
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
remote_workers_only_async: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
@@ -104,23 +125,35 @@ def _run_workers(

class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):

async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())

# Only the driver worker returns the sampling results.
return await self._driver_execute_model_async(execute_model_req)

async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
[Review comment, Collaborator] maybe assert instead? If this is None, doesn't that mean the state is kind of screwed up?

return

await self._driver_execute_model_async()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks

@abstractmethod
async def _run_workers_async(
async def _driver_execute_model_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
raise NotImplementedError

async def execute_model_async(self, *args,
**kwargs) -> List[SamplerOutput]:
all_outputs = await self._run_workers_async("execute_model",
driver_args=args,
driver_kwargs=kwargs)

# Only the driver worker returns the sampling results.
return all_outputs[0]
@abstractmethod
async def _start_worker_execution_loop(self):
raise NotImplementedError
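`_start_worker_execution_loop` is abstract here; its worker-side counterpart, `Worker.start_worker_execution_loop`, is added elsewhere in this PR and does not appear in this diff. As a hedged reconstruction (not the verbatim worker code): a non-driver worker loops on `execute_model(None)`, which blocks on broadcast input from the driver and returns `None` when the driver signals a stop.

```python
from typing import List, Optional


class WorkerSketch:
    """Hedged reconstruction of the worker-side loop; the real code
    lives in vllm's worker module and differs in detail."""

    def execute_model(self, execute_model_req=None) -> Optional[List]:
        # For a non-driver worker called with no request, this blocks
        # on a broadcast from the driver and returns None when the
        # driver signals that the loop should end.
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        # Spin on broadcast-driven steps until the stop signal.
        while True:
            if self.execute_model(execute_model_req=None) is None:
                return
```

This is what makes `remote_workers_only_async=True` in `_run_workers` meaningful: the returned futures resolve only when these loops exit, which is why `stop_remote_worker_execution_loop_async` awaits them to surface any worker-side error.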
8 changes: 8 additions & 0 deletions vllm/executor/executor_base.py
@@ -74,6 +74,10 @@ def execute_model(
"""Executes at least one model step on the given sequences."""
raise NotImplementedError

def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop."""
return

@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
@@ -109,6 +113,10 @@ async def execute_model_async(
"""Executes one model step on the given sequences."""
raise NotImplementedError

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Releases parallel workers from model loop."""
return

async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
51 changes: 24 additions & 27 deletions vllm/executor/multiproc_gpu_executor.py
@@ -1,13 +1,14 @@
import asyncio
import os
from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, List, Optional

from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)

@@ -71,12 +72,17 @@ def shutdown(self):
None)) is not None:
worker_monitor.close()

def _driver_execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
return self.driver_worker.execute_model(
execute_model_req=execute_model_req)

def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
remote_workers_only_async: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
@@ -92,15 +98,13 @@ def _run_workers(
for worker in self.workers
]

if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
if remote_workers_only_async:
# Just return futures
return worker_outputs

# Start the driver worker after all the ray workers.
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*driver_args,
**driver_kwargs)
driver_worker_output = driver_worker_method(*args, **kwargs)

# Get the results of the workers.
return [driver_worker_output
@@ -115,26 +119,19 @@ def check_health(self) -> None:
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
DistributedGPUExecutorAsync):

async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)

driver_executor = make_async(getattr(self.driver_worker, method))
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_model(execute_model_req)

# Run all the workers asynchronously.
coros = [driver_executor(*driver_args, **driver_kwargs)] + [
worker.execute_method_async(method, *args, **kwargs)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method_async("start_worker_execution_loop")
for worker in self.workers
]

return await asyncio.gather(*coros)
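`_driver_execute_model_async` leans on `make_async` (imported from `vllm.utils` at the top of this file) to await the driver's blocking `execute_model` without stalling the event loop. Below is a minimal equivalent, offered as an approximation rather than vllm's exact implementation.

```python
# Approximate equivalent of vllm.utils.make_async (an assumption based
# on its usage here, not a verbatim copy): run a blocking callable in
# the default thread-pool executor so callers can await the result.
import asyncio
from functools import partial
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def make_async(func: Callable[..., T]) -> Callable[..., "asyncio.Future[T]"]:
    def _async_wrapper(*args: Any, **kwargs: Any) -> "asyncio.Future[T]":
        loop = asyncio.get_event_loop()
        # None selects the loop's default ThreadPoolExecutor.
        return loop.run_in_executor(None, partial(func, *args, **kwargs))

    return _async_wrapper
```

Caching the wrapped callable once in `__init__` (`self.driver_exec_model`) also avoids re-wrapping `execute_model` on every step, consistent with the PR's goal of trimming per-step overhead.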
69 changes: 27 additions & 42 deletions vllm/executor/ray_gpu_executor.py
@@ -42,6 +42,8 @@ def _init_executor(self) -> None:
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.extra_execute_model_run_workers_kwargs[
"use_ray_compiled_dag"] = True

def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
@@ -171,23 +173,17 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)

def execute_model(
def _driver_execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={"execute_model_req": execute_model_req},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

# Only the driver worker returns the sampling results.
return all_outputs[0]
return self.driver_worker.execute_method("execute_model",
execute_model_req)

def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
remote_workers_only_async: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
@@ -199,8 +195,6 @@ def _run_workers(
ways:

- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
@@ -209,11 +203,6 @@
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")

if driver_args is None:
driver_args = args if all_args is None else all_args[0]
if driver_kwargs is None:
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
@@ -225,6 +214,7 @@
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
ray_worker_outputs = []
else:
# Start the ray workers first.
ray_worker_outputs = [
@@ -234,6 +224,13 @@
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]

if remote_workers_only_async:
# Just return futures
return ray_worker_outputs

driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
@@ -303,30 +300,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_executor = make_async(self.driver_worker.execute_method)
self.driver_exec_method = make_async(self.driver_worker.execute_method)

async def _run_workers_async(
async def _driver_execute_model_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []

if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs

coros.append(
self.driver_executor(method, *driver_args, **driver_kwargs))

# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))

all_outputs = await asyncio.gather(*coros)
return all_outputs
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req)

async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)
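In the Ray backend, `worker.execute_method.remote(...)` returns `ray.ObjectRef`s, which are directly awaitable in asyncio, so the `gather` above keeps exactly one in-flight call per remote worker for the lifetime of the loop instead of issuing one `.remote()` call per step. Here is a toy demonstration of that idiom; the actor is hypothetical, and vllm's real workers carry far more state.

```python
# Toy demonstration of the Ray idiom used above (the actor is
# hypothetical, not vllm's worker class).
import asyncio

import ray


@ray.remote
class LoopWorker:
    def start_worker_execution_loop(self) -> str:
        # Stand-in for the long-running model execution loop.
        return "loop finished"


async def main() -> None:
    ray.init(ignore_reinit_error=True)
    workers = [LoopWorker.remote() for _ in range(2)]
    # ObjectRefs are awaitable: gather awaits all loops concurrently.
    results = await asyncio.gather(
        *(w.start_worker_execution_loop.remote() for w in workers))
    print(results)  # ['loop finished', 'loop finished']


asyncio.run(main())
```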