Skip to content

[Core] Eliminate parallel worker per-step task scheduling overhead #3763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/distributed/communication_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ def broadcast_tensor_dict(
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"

# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return tensor_dict

ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
Expand Down
6 changes: 5 additions & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,11 @@ async def step_async(self) -> List[RequestOutput]:
else:
output = []

return self._process_model_outputs(output, scheduler_outputs)
outputs = self._process_model_outputs(output, scheduler_outputs)
if not outputs:
# Stop the execute model loop in parallel workers for now
await self.model_executor.stop_remote_worker_execution_loop_async()
return outputs

async def encode_request_async(
self,
Expand Down
6 changes: 5 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,11 @@ def step(self) -> List[RequestOutput]:
else:
output = []

return self._process_model_outputs(output, scheduler_outputs)
outputs = self._process_model_outputs(output, scheduler_outputs)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will return finished outputs twice? LLM object will get duplicate output of same request?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hengxinCheung sorry, I'm not sure I understand the question. This PR doesn't change anything w.r.t. how many outputs are returned.

Copy link

@hengxinCheung hengxinCheung Apr 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sorry for cofusing you. Let me provide a more detailed description. For example, request A marked as finished in the current execution, but it will be scheduled in the next step. So this request will return last generated text twice? I will carefully read your implementation again. Thanks your reply.

if not outputs:
# Stop the execute model loop in parallel workers for now
self.model_executor.stop_remote_worker_execution_loop()
return outputs

def do_log_stats(self) -> None:
"""Forced log when no requests active."""
Expand Down
8 changes: 8 additions & 0 deletions vllm/executor/executor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def execute_model(self,
"""Executes one model step on the given sequences."""
raise NotImplementedError

def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop."""
return

@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
Expand Down Expand Up @@ -108,6 +112,10 @@ async def execute_model_async(
"""Executes one model step on the given sequences."""
raise NotImplementedError

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Releases parallel workers from model loop."""
return

async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
Expand Down
119 changes: 62 additions & 57 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union)

from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
Expand Down Expand Up @@ -36,6 +37,10 @@ def _init_executor(self) -> None:
assert self.parallel_config.worker_use_ray
placement_group = self.parallel_config.placement_group

# This is non-None when the execute model loop is running
# in the parallel workers
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None

# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
Expand Down Expand Up @@ -223,19 +228,30 @@ def execute_model(self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
remote_workers_only_async=True,
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

# Only the driver worker returns the sampling results.
output = all_outputs[0]
return output
return self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy)

def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return

self.driver_worker.execute_model()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
ray.get(parallel_worker_tasks)

def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
Expand All @@ -258,8 +274,7 @@ def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
remote_workers_only_async: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
Expand All @@ -275,21 +290,21 @@ def _run_workers(
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
ray_worker_outputs = []
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
]

if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
if remote_workers_only_async:
# Just return futures
return ray_worker_outputs

# Start the driver worker after all the ray workers.
driver_worker_output = getattr(self.driver_worker,
method)(*driver_args, **driver_kwargs)
driver_worker_output = getattr(self.driver_worker, method)(*args,
**kwargs)

# Get the results of the ray workers.
if self.workers:
Expand Down Expand Up @@ -348,49 +363,39 @@ def _check_if_any_actor_is_dead(self):

class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):

async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []

if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs

# Run the driver worker asynchronously.
driver_executor = make_async(getattr(self.driver_worker, method))
coros.append(driver_executor(*driver_args, **driver_kwargs))

# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))

all_outputs = await asyncio.gather(*coros)
return all_outputs

async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
all_outputs = await self._run_workers_async(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
})
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
self._start_worker_execution_loop())

# Only the driver worker returns the sampling results.
output = all_outputs[0]
return output
return await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy)

async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)

async def stop_remote_worker_execution_loop_async(self) -> None:
if self.parallel_worker_tasks is None:
return

await make_async(self.driver_worker.execute_model)()
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await parallel_worker_tasks
43 changes: 30 additions & 13 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,12 @@ def execute_model(
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
assert self.is_driver_worker
if seq_group_metadata_list is None:
# No data to run, notify other workers to stop the execution loop.
num_seq_groups = 0
data = {}
else:
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
Expand All @@ -223,23 +227,36 @@ def execute_model(
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_swap_in = data["blocks_to_swap_in"]
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
broadcast_tensor_dict(data, src=0)

self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
return None

return self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)

@torch.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker."""
assert not self.is_driver_worker
while True:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data.get("num_seq_groups", 0)
blocks_to_swap_in = data.get("blocks_to_swap_in")
blocks_to_swap_out = data.get("blocks_to_swap_out")
blocks_to_copy = data.get("blocks_to_copy")

self.cache_swap(blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy)

# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return None

output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)
return output
self.model_runner.execute_model(None, self.gpu_cache)

def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
Expand Down