2
2
from collections import defaultdict
3
3
import os
4
4
import time
5
+ import pickle
5
6
from typing import (TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Tuple ,
6
7
Union )
7
8
30
31
logger = init_logger(__name__)

# Interval (in seconds) between periodic emissions of local engine stats.
_LOCAL_LOGGING_INTERVAL_SEC = 5

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# NOTE: os.getenv returns a *string*, so the naive
# bool(os.getenv(..., 0)) would treat "0" as enabled ("0" is a
# non-empty string). Parse the value explicitly instead.
USE_RAY_COMPILED_DAG = os.getenv("VLLM_USE_RAY_COMPILED_DAG",
                                 "0").lower() in ("1", "true")
33
39
34
40
class LLMEngine :
35
41
"""An LLM engine that receives requests and generates texts.
@@ -125,6 +131,10 @@ def __init__(
125
131
self .stat_logger = StatLogger (
126
132
local_interval = _LOCAL_LOGGING_INTERVAL_SEC )
127
133
134
+ self .forward_dag = None
135
+ if USE_RAY_COMPILED_DAG :
136
+ self .forward_dag = self ._compiled_ray_dag ()
137
+
128
138
def get_tokenizer_for_seq(self, sequence: Sequence):
    """Return the tokenizer to use for *sequence*.

    Delegates to the tokenizer group so that, when the sequence
    carries a LoRA request, the LoRA-specific tokenizer is returned.
    """
    lora_request = sequence.lora_request
    return self.tokenizer.get_lora_tokenizer(lora_request)
130
140
@@ -807,7 +817,8 @@ def step(self) -> List[RequestOutput]:
807
817
"blocks_to_swap_in" : scheduler_outputs .blocks_to_swap_in ,
808
818
"blocks_to_swap_out" : scheduler_outputs .blocks_to_swap_out ,
809
819
"blocks_to_copy" : scheduler_outputs .blocks_to_copy ,
810
- })
820
+ },
821
+ use_ray_compiled_dag = USE_RAY_COMPILED_DAG )
811
822
812
823
# Only the driver worker returns the sampling results.
813
824
output = all_outputs [0 ]
@@ -967,6 +978,7 @@ def _run_workers(
967
978
driver_args : Optional [List [Any ]] = None ,
968
979
driver_kwargs : Optional [Dict [str , Any ]] = None ,
969
980
max_concurrent_workers : Optional [int ] = None ,
981
+ use_ray_compiled_dag : bool = False ,
970
982
** kwargs ,
971
983
) -> Any :
972
984
"""Runs the given method on all workers."""
@@ -975,11 +987,16 @@ def _run_workers(
975
987
raise NotImplementedError (
976
988
"max_concurrent_workers is not supported yet." )
977
989
978
- # Start the ray workers first.
979
- ray_worker_outputs = [
980
- worker .execute_method .remote (method , * args , ** kwargs )
981
- for worker in self .workers
982
- ]
990
+ if use_ray_compiled_dag :
991
+ # Right now, compiled DAG can only accept a single
992
+ # input. TODO(sang): Fix it.
993
+ output_channels = self .forward_dag .execute (1 )
994
+ else :
995
+ # Start the ray workers first.
996
+ ray_worker_outputs = [
997
+ worker .execute_method .remote (method , * args , ** kwargs )
998
+ for worker in self .workers
999
+ ]
983
1000
984
1001
if driver_args is None :
985
1002
driver_args = args
@@ -992,6 +1009,37 @@ def _run_workers(
992
1009
993
1010
# Get the results of the ray workers.
994
1011
if self .workers :
995
- ray_worker_outputs = ray .get (ray_worker_outputs )
1012
+ if use_ray_compiled_dag :
1013
+ try :
1014
+ ray_worker_outputs = [
1015
+ pickle .loads (chan .begin_read ())
1016
+ for chan in output_channels
1017
+ ]
1018
+ finally :
1019
+ # Has to call end_read in order to reuse the DAG.
1020
+ for chan in output_channels :
1021
+ chan .end_read ()
1022
+ else :
1023
+ ray_worker_outputs = ray .get (ray_worker_outputs )
996
1024
997
1025
return [driver_worker_output ] + ray_worker_outputs
1026
+
1027
def _compiled_ray_dag(self):
    """Build and compile a Ray DAG that fans the (dummy) input out to
    every non-driver Ray worker's ``execute_model_compiled_dag_remote``.

    Returns:
        The compiled DAG; callers invoke ``.execute(...)`` on it to run
        one forward pass across all Ray workers.

    Raises:
        ValueError: if the installed Ray is older than 2.9 (the first
            release with the compiled DAG API).
    """
    # Use stdlib importlib.metadata rather than the deprecated
    # pkg_resources to look up the installed Ray version.
    from importlib.metadata import version as _installed_version

    required_version = "2.9"
    current_version = _installed_version("ray")

    def _major_minor(v: str) -> Tuple[int, ...]:
        # Leading numeric components only; tolerates suffixes like
        # "2.9.0.dev0".
        parts = []
        for part in v.split(".")[:2]:
            if not part.isdigit():
                break
            parts.append(int(part))
        return tuple(parts)

    # Compare numerically: naive string comparison would wrongly reject
    # newer releases (e.g. "2.10" < "2.9" lexicographically).
    if _major_minor(current_version) < _major_minor(required_version):
        raise ValueError(f"Ray version {required_version} or greater is "
                         f"required, but found {current_version}")

    from ray.dag import MultiOutputNode, InputNode
    assert self.parallel_config.worker_use_ray

    # Right now, compiled DAG requires at least 1 arg. We send
    # a dummy value for now. It will be fixed soon.
    with InputNode() as input_data:
        forward_dag = MultiOutputNode([
            worker.execute_model_compiled_dag_remote.bind(input_data)
            for worker in self.workers
        ])
    return forward_dag.experimental_compile()
0 commit comments