forked from vllm-project/vllm
[Hardware][TPU] Implement tensor parallelism with Ray (vllm-project#5871)
1 parent c6c5cc2
commit 08963df
Showing 6 changed files with 365 additions and 21 deletions.
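The new executor lets a model be sharded across several TPU chips, with Ray scheduling one TPUWorker per chip. As orientation before the diff, here is a minimal usage sketch of how tensor parallelism with the Ray backend would be requested from the vLLM entry point; the model name is illustrative and exact engine flags can differ between vLLM versions, but a tensor_parallel_size greater than 1 together with the "ray" backend is what routes execution through the RayTPUExecutor added here (see the assert in _init_executor below).

from vllm import LLM, SamplingParams

# Hypothetical invocation sketch, not part of this commit.
llm = LLM(
    model="google/gemma-2b",              # illustrative model choice
    tensor_parallel_size=4,               # one TPUWorker per "TPU" bundle
    distributed_executor_backend="ray",   # select the Ray-based executor
)
outputs = llm.generate(["Hello from a TPU pod slice!"],
                       SamplingParams(max_tokens=16))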
@@ -0,0 +1,313 @@
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
                    Union)

import vllm.envs as envs
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.executor.tpu_executor import TPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        get_vllm_instance_id, make_async)

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


class RayTPUExecutor(TPUExecutor):

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def _init_executor(self) -> None:
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel TPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("TPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            assert self.speculative_config is None
            worker_module_name = "vllm.worker.tpu_worker"
            worker_class_name = "TPUWorker"

            worker = ray.remote(
                num_cpus=0,
                resources={"TPU": 1},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, add it to the list of remote workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any TPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "TPU node.")

        # Get the set of TPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)

        VLLM_INSTANCE_ID = get_vllm_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "VLLM_INSTANCE_ID":
            VLLM_INSTANCE_ID,
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
        }, ) for _ in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_workers) == 1:
            # In the single node case, we don't need to get the IP address;
            # the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside the worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True the method will be run only
          in the remote workers, not the driver worker. It will also be
          run asynchronously and return a list of futures rather than blocking
          on the results.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))
        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_blocks = self._run_workers("determine_num_available_blocks", )
        num_tpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)


class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
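For readers unfamiliar with Ray placement groups, the standalone sketch below isolates the scheduling pattern _init_workers_ray relies on: reserve one bundle per accelerator, then pin one actor per bundle with PlacementGroupSchedulingStrategy. The EchoWorker actor, the bundle count, and the locally faked "TPU" resource are illustrative assumptions for a runnable toy; none of this is vLLM code.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# Fake 4 "TPU" custom resources so the toy runs on any machine; on a real
# TPU pod slice the Ray nodes would advertise this resource themselves.
ray.init(resources={"TPU": 4})

# Reserve one bundle per chip and wait until all bundles are granted.
pg = placement_group([{"TPU": 1}] * 4)
ray.get(pg.ready())

@ray.remote(num_cpus=0, resources={"TPU": 1})
class EchoWorker:
    # Stand-in for RayWorkerWrapper/TPUWorker: just reports its node IP,
    # mirroring the get_node_ip call used to pick the driver's node above.
    def get_node_ip(self) -> str:
        return ray.util.get_node_ip_address()

workers = [
    EchoWorker.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=i,
        )).remote()
    for i in range(4)
]
print(ray.get([w.get_node_ip.remote() for w in workers]))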