
Commit 65b89d1

[Ray] Integration compiled DAG off by default (#2471)
1 parent 931746b commit 65b89d1

File tree: 2 files changed (+73, -7 lines)

vllm/engine/llm_engine.py

Lines changed: 55 additions & 7 deletions
@@ -2,6 +2,7 @@
 from collections import defaultdict
 import os
 import time
+import pickle
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)

@@ -30,6 +31,11 @@
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5

+# If the env var is set, vLLM uses Ray's compiled DAG API,
+# which optimizes the control plane overhead.
+# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
+

 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -124,6 +130,10 @@ def __init__(
             self.stat_logger = StatLogger(
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC)

+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

@@ -806,7 +816,8 @@ def step(self) -> List[RequestOutput]:
                     "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                     "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                     "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                })
+                },
+                use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

             # Only the driver worker returns the sampling results.
             output = all_outputs[0]
@@ -966,6 +977,7 @@ def _run_workers(
         driver_args: Optional[List[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
@@ -974,11 +986,16 @@
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")

-        # Start the ray workers first.
-        ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
-            for worker in self.workers
-        ]
+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input. TODO(sang): Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]

         if driver_args is None:
             driver_args = args
@@ -991,6 +1008,37 @@

         # Get the results of the ray workers.
         if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)

         return [driver_worker_output] + ray_worker_outputs
+
+    def _compiled_ray_dag(self):
+        import pkg_resources
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+        if current_version < required_version:
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import MultiOutputNode, InputNode
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
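
For context, _compiled_ray_dag follows the pattern of Ray's experimental compiled (accelerated) DAG API: build the DAG once from InputNode/MultiOutputNode, compile it, then drive every step through execute() and the returned output channels. The standalone sketch below mirrors the calls made in this diff (bind, experimental_compile, execute, begin_read, end_read) but substitutes a toy actor for vLLM's workers; it assumes Ray >= 2.9 with the experimental feature available, and the exact channel semantics may differ in later Ray versions.

import pickle

import ray
from ray.dag import InputNode, MultiOutputNode


@ray.remote
class ToyWorker:
    """Stand-in for RayWorkerVllm's execute_model_compiled_dag_remote."""

    def execute(self, ignored):
        # A real worker would run model execution here; return pickled bytes,
        # as the vLLM integration does.
        return pickle.dumps({"tokens": [1, 2, 3]})


ray.init()
workers = [ToyWorker.remote() for _ in range(2)]

# Build and compile the DAG once: a single input fanned out to all workers.
with InputNode() as input_data:
    dag = MultiOutputNode([w.execute.bind(input_data) for w in workers])
compiled_dag = dag.experimental_compile()

# Each execute() returns one output channel per worker. begin_read() blocks
# until the result is available; end_read() must be called before the next
# execute() so the DAG's channels can be reused.
channels = compiled_dag.execute(1)
try:
    outputs = [pickle.loads(chan.begin_read()) for chan in channels]
finally:
    for chan in channels:
        chan.end_read()
print(outputs)
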

vllm/engine/ray_utils.py

Lines changed: 18 additions & 0 deletions
@@ -1,3 +1,5 @@
+import pickle
+
 from typing import Optional, List, Tuple, TYPE_CHECKING

 from vllm.config import ParallelConfig
@@ -18,6 +20,11 @@ def __init__(self, init_cached_hf_modules=False) -> None:
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
             self.worker = None
+            # The compiled DAG runs its main execution in a
+            # different thread, which calls cuda.set_device.
+            # This flag records whether set_device has already
+            # been called on that thread.
+            self.compiled_dag_cuda_device_set = False

         def init_worker(self, worker_init_fn):
             self.worker = worker_init_fn()
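
The flag added above exists because CUDA device selection is per host thread: torch.cuda.set_device wraps cudaSetDevice, so a device chosen on the worker's main thread does not carry over to the thread on which the compiled DAG executes. A small illustration of that thread-local behavior, not part of this commit and assuming a machine with at least two CUDA devices:

import threading

import torch


def report(label: str) -> None:
    # current_device() reflects the calling thread's CUDA device selection.
    print(label, torch.cuda.current_device())


torch.cuda.set_device(1)
report("main thread:")  # prints 1

# A fresh thread starts on the default device (0) until it calls set_device
# itself, which is what execute_model_compiled_dag_remote guards with
# compiled_dag_cuda_device_set.
t = threading.Thread(target=report, args=("dag thread:",))
t.start()
t.join()
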
@@ -40,6 +47,17 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
         def set_cuda_visible_devices(self, device_ids) -> None:
             set_cuda_visible_devices(device_ids)

+        def execute_model_compiled_dag_remote(self, ignored):
+            """Used only when compiled DAG is enabled."""
+            import torch
+            if not self.compiled_dag_cuda_device_set:
+                torch.cuda.set_device(self.worker.device)
+                self.compiled_dag_cuda_device_set = True
+
+            output = self.worker.execute_model()
+            output = pickle.dumps(output)
+            return output
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
