2
2
from collections import defaultdict
3
3
import os
4
4
import time
5
+ import pickle
5
6
from typing import (TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Tuple ,
6
7
Union )
7
8
30
31
logger = init_logger (__name__ )
31
32
_LOCAL_LOGGING_INTERVAL_SEC = 5
32
33
34
# If the env var is set to a truthy value, use Ray's compiled DAG API,
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# NOTE: os.getenv returns a *string*, and bool("0") is True, so the
# previous ``bool(os.getenv(...))`` form treated
# VLLM_USE_RAY_COMPILED_DAG=0 as *enabled*. Compare the raw value
# against explicit "disabled" spellings instead.
USE_RAY_COMPILED_DAG = (
    os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0").strip().lower()
    not in ("", "0", "false"))
33
39
34
40
class LLMEngine :
35
41
"""An LLM engine that receives requests and generates texts.
@@ -124,6 +130,10 @@ def __init__(
124
130
self .stat_logger = StatLogger (
125
131
local_interval = _LOCAL_LOGGING_INTERVAL_SEC )
126
132
133
+ self .forward_dag = None
134
+ if USE_RAY_COMPILED_DAG :
135
+ self .forward_dag = self ._compiled_ray_dag ()
136
+
127
137
def get_tokenizer_for_seq(self, sequence: Sequence):
    """Return the tokenizer for *sequence*, resolved via its LoRA request."""
    lora_request = sequence.lora_request
    return self.tokenizer.get_lora_tokenizer(lora_request)
129
139
@@ -806,7 +816,8 @@ def step(self) -> List[RequestOutput]:
806
816
"blocks_to_swap_in" : scheduler_outputs .blocks_to_swap_in ,
807
817
"blocks_to_swap_out" : scheduler_outputs .blocks_to_swap_out ,
808
818
"blocks_to_copy" : scheduler_outputs .blocks_to_copy ,
809
- })
819
+ },
820
+ use_ray_compiled_dag = USE_RAY_COMPILED_DAG )
810
821
811
822
# Only the driver worker returns the sampling results.
812
823
output = all_outputs [0 ]
@@ -966,6 +977,7 @@ def _run_workers(
966
977
driver_args : Optional [List [Any ]] = None ,
967
978
driver_kwargs : Optional [Dict [str , Any ]] = None ,
968
979
max_concurrent_workers : Optional [int ] = None ,
980
+ use_ray_compiled_dag : bool = False ,
969
981
** kwargs ,
970
982
) -> Any :
971
983
"""Runs the given method on all workers."""
@@ -974,11 +986,16 @@ def _run_workers(
974
986
raise NotImplementedError (
975
987
"max_concurrent_workers is not supported yet." )
976
988
977
- # Start the ray workers first.
978
- ray_worker_outputs = [
979
- worker .execute_method .remote (method , * args , ** kwargs )
980
- for worker in self .workers
981
- ]
989
+ if use_ray_compiled_dag :
990
+ # Right now, compiled DAG can only accept a single
991
+ # input. TODO(sang): Fix it.
992
+ output_channels = self .forward_dag .execute (1 )
993
+ else :
994
+ # Start the ray workers first.
995
+ ray_worker_outputs = [
996
+ worker .execute_method .remote (method , * args , ** kwargs )
997
+ for worker in self .workers
998
+ ]
982
999
983
1000
if driver_args is None :
984
1001
driver_args = args
@@ -991,6 +1008,37 @@ def _run_workers(
991
1008
992
1009
# Get the results of the ray workers.
993
1010
if self .workers :
994
- ray_worker_outputs = ray .get (ray_worker_outputs )
1011
+ if use_ray_compiled_dag :
1012
+ try :
1013
+ ray_worker_outputs = [
1014
+ pickle .loads (chan .begin_read ())
1015
+ for chan in output_channels
1016
+ ]
1017
+ finally :
1018
+ # Has to call end_read in order to reuse the DAG.
1019
+ for chan in output_channels :
1020
+ chan .end_read ()
1021
+ else :
1022
+ ray_worker_outputs = ray .get (ray_worker_outputs )
995
1023
996
1024
return [driver_worker_output ] + ray_worker_outputs
1025
+
1026
def _compiled_ray_dag(self):
    """Build and compile a Ray DAG that fans one input out to all workers.

    Binds ``execute_model_compiled_dag_remote`` on every Ray worker to a
    single shared input node and compiles the resulting multi-output DAG.

    Returns:
        The compiled DAG object; callers invoke ``.execute(...)`` on it.

    Raises:
        ValueError: if the installed Ray is older than 2.9 (the first
            release with the compiled-DAG API).
    """
    import pkg_resources
    required_version = "2.9"
    current_version = pkg_resources.get_distribution("ray").version
    # Compare *parsed* versions, not raw strings: lexicographically
    # "2.10" < "2.9", so a plain string comparison would wrongly
    # reject newer Ray releases.
    if (pkg_resources.parse_version(current_version)
            < pkg_resources.parse_version(required_version)):
        raise ValueError(f"Ray version {required_version} or greater is "
                         f"required, but found {current_version}")

    from ray.dag import MultiOutputNode, InputNode
    # Compiled DAGs only make sense when workers are Ray actors.
    assert self.parallel_config.worker_use_ray

    # Right now, compiled DAG requires at least 1 arg. We send
    # a dummy value for now. It will be fixed soon.
    with InputNode() as input_data:
        forward_dag = MultiOutputNode([
            worker.execute_model_compiled_dag_remote.bind(input_data)
            for worker in self.workers
        ])
    return forward_dag.experimental_compile()
0 commit comments