2
2
import os
3
3
from typing import List , Optional
4
4
5
+ from ray .exceptions import ActorDiedError
5
6
from ray .util .scheduling_strategies import NodeAffinitySchedulingStrategy
6
7
from transformers import PreTrainedTokenizer
7
8
8
9
from vllm .config import TokenizerPoolConfig
9
10
from vllm .executor .ray_utils import ray
11
+ from vllm .logger import init_logger
10
12
from vllm .lora .request import LoRARequest
11
13
from vllm .transformers_utils .tokenizer_group .base_tokenizer_group import (
12
14
BaseTokenizerGroup )
13
15
from vllm .transformers_utils .tokenizer_group .tokenizer_group import (
14
16
TokenizerGroup )
15
17
18
+ logger = init_logger (__name__ )
19
+
16
20
17
21
class RayTokenizerGroupPool (BaseTokenizerGroup ):
18
22
"""A Ray-based pool of TokenizerGroups for async tokenization."""
@@ -46,24 +50,28 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
46
50
ray_actor_options : dict , ** tokenizer_config ):
47
51
# Store a local copy of the TokenizerGroup for quick access
48
52
# to underlying HF tokenizers.
53
+ self ._tokenizer_config = {
54
+ "tokenizer_id" : tokenizer_id ,
55
+ "enable_lora" : enable_lora ,
56
+ "max_num_seqs" : max_num_seqs ,
57
+ "max_input_length" : max_input_length ,
58
+ ** tokenizer_config
59
+ }
49
60
self ._local_tokenizer_group = self ._worker_cls (
50
- tokenizer_id = tokenizer_id ,
51
- enable_lora = enable_lora ,
52
- max_num_seqs = max_num_seqs ,
53
- max_input_length = max_input_length ,
54
- ** tokenizer_config ,
55
- )
56
-
57
- ray_tokenizer_group_cls = ray .remote (
61
+ ** self ._tokenizer_config , )
62
+
63
+ self ._ray_tokenizer_group_cls = ray .remote (
58
64
self ._worker_cls ).options (** ray_actor_options )
59
- self .tokenizer_actors = [
60
- ray_tokenizer_group_cls .remote (tokenizer_id , enable_lora ,
61
- max_num_seqs , max_input_length ,
62
- ** tokenizer_config )
63
- for _ in range (num_actors )
64
- ]
65
+ self .tokenizer_actors = [self ._init_actor () for _ in range (num_actors )]
65
66
self ._idle_actors : Optional [asyncio .Queue ] = None
66
67
68
+ # If set, actor is unhealthy. Will reraise on the next
69
+ # check_health call.
70
+ self ._exception : Optional [ActorDiedError ] = None
71
+
72
def _init_actor(self) -> "ray.actor.ActorHandle":
    """Start one remote tokenizer worker and return its actor handle.

    The actor is created from the Ray actor class stored in
    ``self._ray_tokenizer_group_cls`` and receives the keyword arguments
    held in ``self._tokenizer_config``.

    Returns:
        The result of calling ``.remote(...)`` on the actor class. Callers
        invoke ``actor.encode.remote(...)`` on it, so it is an actor
        handle rather than an ``ObjectRef``.
    """
    return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
74
+
67
75
@property
def pool_size(self) -> int:
    """Number of tokenizer actors currently tracked by this pool."""
    actors = self.tokenizer_actors
    return len(actors)
@@ -78,6 +86,22 @@ def _ensure_queue_initialized(self):
78
86
for actor in self .tokenizer_actors :
79
87
self ._idle_actors .put_nowait (actor )
80
88
89
def _finalize_encode(self, actor: "ray.actor.ActorHandle",
                     original_actor: "ray.actor.ActorHandle",
                     actor_is_alive: bool) -> None:
    """Restore the pool's bookkeeping after an encode attempt.

    The encode methods call this from their ``finally`` block, so it runs
    whether the encode succeeded, raised, or was cancelled.

    Args:
        actor: The actor that served the last attempt: either the one
            taken from the idle queue or the replacement created for it.
        original_actor: The actor originally taken from the idle queue.
        actor_is_alive: False when ``actor`` died as well and must not be
            handed out again.
    """
    assert self._idle_actors is not None
    replaced = original_actor is not actor
    # Cleanup the dead actor.
    if replaced or not actor_is_alive:
        self.tokenizer_actors.remove(original_actor)
    if actor_is_alive:
        # Hand the usable actor back to the idle queue so later requests
        # can pick it up.
        self._idle_actors.put_nowait(actor)
        # Add back the new actor.
        if replaced:
            self.tokenizer_actors.append(actor)
104
+
81
105
def encode (self ,
82
106
prompt : str ,
83
107
request_id : Optional [str ] = None ,
@@ -88,23 +112,41 @@ def encode(self,
88
112
The actor is then put back in the queue for future use.
89
113
This is blocking.
90
114
"""
115
+ self .check_health ()
91
116
self ._ensure_queue_initialized ()
92
117
assert self ._idle_actors is not None
93
118
94
119
if self ._idle_actors .empty ():
95
120
raise RuntimeError ("No idle actors available." )
96
121
actor = self ._idle_actors .get_nowait ()
122
+ actor_is_alive = True
123
+ original_actor = actor
97
124
try :
98
125
ret = ray .get (
99
126
actor .encode .remote (request_id = request_id ,
100
127
prompt = prompt ,
101
128
lora_request = lora_request ))
129
+ except ActorDiedError as e :
130
+ # If the actor is dead, we first try to reinitialize it.
131
+ logger .warning ("%s died with ActorDiedError, reinitializing." ,
132
+ actor ,
133
+ exc_info = e )
134
+ actor = self ._init_actor ()
135
+ try :
136
+ ret = ray .get (
137
+ actor .encode .remote (request_id = request_id ,
138
+ prompt = prompt ,
139
+ lora_request = lora_request ))
140
+ except ActorDiedError as e :
141
+ logger .error (
142
+ "%s died for second time in a row, marking "
143
+ "RayTokenizerGroupPool as unhealthy." , actor )
144
+ actor_is_alive = False
145
+ if not self ._exception :
146
+ self ._exception = e
147
+ self .check_health ()
102
148
finally :
103
- # Put the actor back in the queue.
104
- # This is done in a finally block to ensure that the actor is
105
- # always put back in the queue, even if an exception/cancellation
106
- # is raised.
107
- self ._idle_actors .put_nowait (actor )
149
+ self ._finalize_encode (actor , original_actor , actor_is_alive )
108
150
return ret
109
151
110
152
async def encode_async (
@@ -120,20 +162,37 @@ async def encode_async(
120
162
The actor is then put back in the queue for future use.
121
163
This is non-blocking.
122
164
"""
165
+ self .check_health ()
123
166
self ._ensure_queue_initialized ()
124
167
assert self ._idle_actors is not None
125
168
126
169
actor = await self ._idle_actors .get ()
170
+ actor_is_alive = True
171
+ original_actor = actor
127
172
try :
128
173
ret = await actor .encode .remote (request_id = request_id ,
129
174
prompt = prompt ,
130
175
lora_request = lora_request )
176
+ except ActorDiedError as e :
177
+ # If the actor is dead, we first try to reinitialize it.
178
+ logger .warning ("%s died with ActorDiedError, reinitializing." ,
179
+ actor ,
180
+ exc_info = e )
181
+ actor = self ._init_actor ()
182
+ try :
183
+ ret = await actor .encode .remote (request_id = request_id ,
184
+ prompt = prompt ,
185
+ lora_request = lora_request )
186
+ except ActorDiedError as e :
187
+ logger .error (
188
+ "%s died for second time in a row, marking "
189
+ "RayTokenizerGroupPool as unhealthy." , actor )
190
+ actor_is_alive = False
191
+ if not self ._exception :
192
+ self ._exception = e
193
+ self .check_health ()
131
194
finally :
132
- # Put the actor back in the queue.
133
- # This is done in a finally block to ensure that the actor is
134
- # always put back in the queue, even if an exception/cancellation
135
- # is raised.
136
- self ._idle_actors .put_nowait (actor )
195
+ self ._finalize_encode (actor , original_actor , actor_is_alive )
137
196
return ret
138
197
139
198
def get_max_input_len (self ,
@@ -155,6 +214,11 @@ async def get_lora_tokenizer_async(
155
214
return await self ._local_tokenizer_group .get_lora_tokenizer_async (
156
215
lora_request )
157
216
217
def check_health(self):
    """Raise if the pool has been marked unhealthy.

    Raises:
        RuntimeError: when ``self._exception`` holds a recorded failure;
            the stored exception is chained as the cause.
    """
    failure = self._exception
    if not failure:
        return
    raise RuntimeError("TokenizerGroupPool is unhealthy.") from failure
221
+
158
222
159
223
def _carry_over_env_vars_to_runtime_env (runtime_env : dict ) -> None :
160
224
"""Copy over all current process environment variables to the runtime_env.
0 commit comments