
Commit 23a9ded

rkooo567 authored and alexm-redhat committed
[Ray] Integration compiled DAG off by default (vllm-project#2471)
1 parent 2da4b50 commit 23a9ded

2 files changed: 73 additions & 7 deletions

vllm/engine/llm_engine.py

Lines changed: 55 additions & 7 deletions
@@ -2,6 +2,7 @@
 from collections import defaultdict
 import os
 import time
+import pickle
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 

@@ -30,6 +31,11 @@
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
+# If the env var is set, vLLM uses Ray's compiled DAG API,
+# which optimizes the control plane overhead.
+# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
+
 
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -125,6 +131,10 @@ def __init__(
         self.stat_logger = StatLogger(
             local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
 
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 

@@ -807,7 +817,8 @@ def step(self) -> List[RequestOutput]:
                 "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                 "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                 "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-            })
+            },
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
@@ -967,6 +978,7 @@ def _run_workers(
         driver_args: Optional[List[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
@@ -975,11 +987,16 @@
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        # Start the ray workers first.
-        ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
-            for worker in self.workers
-        ]
+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input. TODO(sang): Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]
 
         if driver_args is None:
             driver_args = args
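
The branch above is where the savings come from: the default path issues one execute_method.remote() call per worker on every step, each paying Ray's task-submission overhead, while the compiled DAG path replays a graph compiled once at startup. A toy illustration of the default path's per-step cost; the worker class and method body here are illustrative, not vLLM's:

    import ray

    ray.init()

    @ray.remote
    class Worker:
        def execute_method(self, method, *args, **kwargs):
            return f"ran {method}"

    workers = [Worker.remote() for _ in range(4)]

    # Default path: 4 task submissions plus 1 ray.get per engine step.
    outputs = ray.get([w.execute_method.remote("execute_model") for w in workers])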
@@ -992,6 +1009,37 @@
 
         # Get the results of the ray workers.
         if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)
 
         return [driver_worker_output] + ray_worker_outputs
+
+    def _compiled_ray_dag(self):
+        import pkg_resources
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+        if current_version < required_version:
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import InputNode, MultiOutputNode
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
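
A caveat in _compiled_ray_dag: the version guard compares strings lexicographically, so a hypothetical Ray "2.10" would sort before "2.9" and be rejected. A minimal sketch of a numeric comparison using pkg_resources' own parser; the surrounding logic is unchanged:

    import pkg_resources

    # parse_version compares release components numerically, so 2.10 > 2.9.
    required = pkg_resources.parse_version("2.9")
    current = pkg_resources.parse_version(
        pkg_resources.get_distribution("ray").version)
    if current < required:
        raise ValueError(f"Ray 2.9 or greater is required, but found {current}")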

vllm/engine/ray_utils.py

Lines changed: 18 additions & 0 deletions
@@ -1,3 +1,5 @@
+import pickle
+
 from typing import Optional, List, Tuple, TYPE_CHECKING
 
 from vllm.config import ParallelConfig
@@ -18,6 +20,11 @@ def __init__(self, init_cached_hf_modules=False) -> None:
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
             self.worker = None
+            # The compiled DAG runs its main execution on a different
+            # thread, and cuda.set_device must be called on that
+            # thread. This flag records whether set_device has already
+            # been called on that thread.
+            self.compiled_dag_cuda_device_set = False
 
         def init_worker(self, worker_init_fn):
             self.worker = worker_init_fn()
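
The flag is needed because the CUDA runtime tracks the current device per thread: a thread spawned by the compiled DAG starts on the default device regardless of what the main thread selected. A small illustration, assuming a host with at least two GPUs:

    import threading
    import torch

    def child():
        # A new thread does not inherit the parent's current device;
        # it starts on the default device (GPU 0).
        print("child device:", torch.cuda.current_device())

    torch.cuda.set_device(1)
    print("main device:", torch.cuda.current_device())  # -> 1

    t = threading.Thread(target=child)  # prints 0
    t.start()
    t.join()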
@@ -40,6 +47,17 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
         def set_cuda_visible_devices(self, device_ids) -> None:
             set_cuda_visible_devices(device_ids)
 
+        def execute_model_compiled_dag_remote(self, ignored):
+            """Used only when compiled DAG is enabled."""
+            import torch
+            if not self.compiled_dag_cuda_device_set:
+                torch.cuda.set_device(self.worker.device)
+                self.compiled_dag_cuda_device_set = True
+
+            output = self.worker.execute_model()
+            output = pickle.dumps(output)
+            return output
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
