Skip to content

Commit 67882db

Browse files
authored
[Core] Add fault tolerance for RayTokenizerGroupPool (#5748)
1 parent 7b99314 commit 67882db

File tree

5 files changed

+195
-24
lines changed

5 files changed

+195
-24
lines changed

tests/tokenization/test_tokenizer_group.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22
import os
3+
import sys
4+
from typing import List, Optional
35
from unittest.mock import patch
46

57
import pytest
@@ -100,3 +102,100 @@ class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
100102
max_num_seqs=1,
101103
max_input_length=None)
102104
tokenizer_pool.ping()
105+
106+
107+
@pytest.mark.asyncio
108+
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
109+
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
110+
"""Test that the Ray tokenizer pool group can recover from failures and
111+
if that's not possible, mark itself as unhealthy."""
112+
113+
class FailingTokenizerGroup(TokenizerGroup):
114+
115+
def __init__(self,
116+
*args,
117+
fail_at: Optional[List[int]] = None,
118+
**kwargs):
119+
super().__init__(*args, **kwargs)
120+
self.i = 0
121+
self.fail_at = fail_at or []
122+
123+
def encode(self, *args, **kwargs):
124+
self.i += 1
125+
if self.i in self.fail_at:
126+
sys.exit(1)
127+
return super().encode(*args, **kwargs)
128+
129+
class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
130+
_worker_cls = FailingTokenizerGroup
131+
132+
# Fail at first iteration
133+
fail_at = [1]
134+
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
135+
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
136+
tokenizer_pool_config,
137+
tokenizer_id="gpt2",
138+
enable_lora=False,
139+
max_num_seqs=1,
140+
max_input_length=None,
141+
fail_at=fail_at)
142+
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
143+
144+
# Modify fail_at to not fail at all (will be re-read when actor is
145+
# re-initialized).
146+
fail_at[0] = 1000
147+
148+
# We should recover successfully.
149+
await tokenizer_group_pool.encode_async(request_id="1",
150+
prompt="prompt",
151+
lora_request=None)
152+
await tokenizer_group_pool.encode_async(request_id="1",
153+
prompt="prompt",
154+
lora_request=None)
155+
156+
# Check that we have a new actor
157+
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
158+
assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors
159+
160+
# Fail at first iteration
161+
fail_at = [1]
162+
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
163+
tokenizer_pool_config,
164+
tokenizer_id="gpt2",
165+
enable_lora=False,
166+
max_num_seqs=1,
167+
max_input_length=None,
168+
fail_at=fail_at)
169+
170+
# We should fail after re-initialization.
171+
with pytest.raises(RuntimeError):
172+
await tokenizer_group_pool.encode_async(request_id="1",
173+
prompt="prompt",
174+
lora_request=None)
175+
176+
# check_health should raise the same thing
177+
with pytest.raises(RuntimeError):
178+
tokenizer_group_pool.check_health()
179+
180+
# Ensure that non-ActorDiedErrors are still propagated correctly and do not
181+
# cause a re-initialization.
182+
fail_at = []
183+
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
184+
tokenizer_pool_config,
185+
tokenizer_id="gpt2",
186+
enable_lora=False,
187+
max_num_seqs=1,
188+
max_input_length=2,
189+
fail_at=fail_at)
190+
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
191+
192+
# Prompt too long error
193+
with pytest.raises(ValueError):
194+
await tokenizer_group_pool.encode_async(request_id="1",
195+
prompt="prompt" * 100,
196+
lora_request=None)
197+
await tokenizer_group_pool.encode_async(request_id="1",
198+
prompt="prompt",
199+
lora_request=None)
200+
# Actors should stay the same.
201+
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors

vllm/engine/async_llm_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ async def add_request_async(
310310
)
311311

312312
async def check_health_async(self) -> None:
313+
if self.tokenizer:
314+
self.tokenizer.check_health()
313315
self.model_executor.check_health()
314316

315317

vllm/engine/llm_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,8 @@ def pin_lora(self, lora_id: int) -> bool:
10131013
return self.model_executor.pin_lora(lora_id)
10141014

10151015
def check_health(self) -> None:
1016+
if self.tokenizer:
1017+
self.tokenizer.check_health()
10161018
self.model_executor.check_health()
10171019

10181020
def is_tracing_enabled(self) -> bool:

vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,7 @@ async def get_lora_tokenizer_async(
5353
) -> "PreTrainedTokenizer":
5454
"""Get a tokenizer for a LoRA request."""
5555
pass
56+
57+
def check_health(self):
58+
"""Raise exception if the tokenizer group is unhealthy."""
59+
return

vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@
22
import os
33
from typing import List, Optional
44

5+
from ray.exceptions import ActorDiedError
56
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
67
from transformers import PreTrainedTokenizer
78

89
from vllm.config import TokenizerPoolConfig
910
from vllm.executor.ray_utils import ray
11+
from vllm.logger import init_logger
1012
from vllm.lora.request import LoRARequest
1113
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
1214
BaseTokenizerGroup)
1315
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
1416
TokenizerGroup)
1517

18+
logger = init_logger(__name__)
19+
1620

1721
class RayTokenizerGroupPool(BaseTokenizerGroup):
1822
"""A Ray-based pool of TokenizerGroups for async tokenization."""
@@ -46,24 +50,28 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
4650
ray_actor_options: dict, **tokenizer_config):
4751
# Store a local copy of the TokenizerGroup for quick access
4852
# to underlying HF tokenizers.
53+
self._tokenizer_config = {
54+
"tokenizer_id": tokenizer_id,
55+
"enable_lora": enable_lora,
56+
"max_num_seqs": max_num_seqs,
57+
"max_input_length": max_input_length,
58+
**tokenizer_config
59+
}
4960
self._local_tokenizer_group = self._worker_cls(
50-
tokenizer_id=tokenizer_id,
51-
enable_lora=enable_lora,
52-
max_num_seqs=max_num_seqs,
53-
max_input_length=max_input_length,
54-
**tokenizer_config,
55-
)
56-
57-
ray_tokenizer_group_cls = ray.remote(
61+
**self._tokenizer_config, )
62+
63+
self._ray_tokenizer_group_cls = ray.remote(
5864
self._worker_cls).options(**ray_actor_options)
59-
self.tokenizer_actors = [
60-
ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
61-
max_num_seqs, max_input_length,
62-
**tokenizer_config)
63-
for _ in range(num_actors)
64-
]
65+
self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
6566
self._idle_actors: Optional[asyncio.Queue] = None
6667

68+
# If set, actor is unhealthy. Will reraise on the next
69+
# check_health call.
70+
self._exception: Optional[ActorDiedError] = None
71+
72+
def _init_actor(self) -> ray.ObjectRef:
73+
return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
74+
6775
@property
6876
def pool_size(self) -> int:
6977
return len(self.tokenizer_actors)
@@ -78,6 +86,22 @@ def _ensure_queue_initialized(self):
7886
for actor in self.tokenizer_actors:
7987
self._idle_actors.put_nowait(actor)
8088

89+
def _finalize_encode(self, actor: ray.ObjectRef,
90+
original_actor: ray.ObjectRef, actor_is_alive: bool):
91+
assert self._idle_actors is not None
92+
# Cleanup the dead actor.
93+
if not actor_is_alive or original_actor is not actor:
94+
self.tokenizer_actors.remove(original_actor)
95+
if actor_is_alive:
96+
# Put the actor back in the queue.
97+
# This is done in a finally block to ensure that the actor is
98+
# always put back in the queue, even if an exception/cancellation
99+
# is raised.
100+
self._idle_actors.put_nowait(actor)
101+
# Add back the new actor.
102+
if original_actor is not actor:
103+
self.tokenizer_actors.append(actor)
104+
81105
def encode(self,
82106
prompt: str,
83107
request_id: Optional[str] = None,
@@ -88,23 +112,41 @@ def encode(self,
88112
The actor is then put back in the queue for future use.
89113
This is blocking.
90114
"""
115+
self.check_health()
91116
self._ensure_queue_initialized()
92117
assert self._idle_actors is not None
93118

94119
if self._idle_actors.empty():
95120
raise RuntimeError("No idle actors available.")
96121
actor = self._idle_actors.get_nowait()
122+
actor_is_alive = True
123+
original_actor = actor
97124
try:
98125
ret = ray.get(
99126
actor.encode.remote(request_id=request_id,
100127
prompt=prompt,
101128
lora_request=lora_request))
129+
except ActorDiedError as e:
130+
# If the actor is dead, we first try to reinitialize it.
131+
logger.warning("%s died with ActorDiedError, reinitializing.",
132+
actor,
133+
exc_info=e)
134+
actor = self._init_actor()
135+
try:
136+
ret = ray.get(
137+
actor.encode.remote(request_id=request_id,
138+
prompt=prompt,
139+
lora_request=lora_request))
140+
except ActorDiedError as e:
141+
logger.error(
142+
"%s died for second time in a row, marking "
143+
"RayTokenizerGroupPool as unhealthy.", actor)
144+
actor_is_alive = False
145+
if not self._exception:
146+
self._exception = e
147+
self.check_health()
102148
finally:
103-
# Put the actor back in the queue.
104-
# This is done in a finally block to ensure that the actor is
105-
# always put back in the queue, even if an exception/cancellation
106-
# is raised.
107-
self._idle_actors.put_nowait(actor)
149+
self._finalize_encode(actor, original_actor, actor_is_alive)
108150
return ret
109151

110152
async def encode_async(
@@ -120,20 +162,37 @@ async def encode_async(
120162
The actor is then put back in the queue for future use.
121163
This is non-blocking.
122164
"""
165+
self.check_health()
123166
self._ensure_queue_initialized()
124167
assert self._idle_actors is not None
125168

126169
actor = await self._idle_actors.get()
170+
actor_is_alive = True
171+
original_actor = actor
127172
try:
128173
ret = await actor.encode.remote(request_id=request_id,
129174
prompt=prompt,
130175
lora_request=lora_request)
176+
except ActorDiedError as e:
177+
# If the actor is dead, we first try to reinitialize it.
178+
logger.warning("%s died with ActorDiedError, reinitializing.",
179+
actor,
180+
exc_info=e)
181+
actor = self._init_actor()
182+
try:
183+
ret = await actor.encode.remote(request_id=request_id,
184+
prompt=prompt,
185+
lora_request=lora_request)
186+
except ActorDiedError as e:
187+
logger.error(
188+
"%s died for second time in a row, marking "
189+
"RayTokenizerGroupPool as unhealthy.", actor)
190+
actor_is_alive = False
191+
if not self._exception:
192+
self._exception = e
193+
self.check_health()
131194
finally:
132-
# Put the actor back in the queue.
133-
# This is done in a finally block to ensure that the actor is
134-
# always put back in the queue, even if an exception/cancellation
135-
# is raised.
136-
self._idle_actors.put_nowait(actor)
195+
self._finalize_encode(actor, original_actor, actor_is_alive)
137196
return ret
138197

139198
def get_max_input_len(self,
@@ -155,6 +214,11 @@ async def get_lora_tokenizer_async(
155214
return await self._local_tokenizer_group.get_lora_tokenizer_async(
156215
lora_request)
157216

217+
def check_health(self):
218+
if self._exception:
219+
raise RuntimeError(
220+
"TokenizerGroupPool is unhealthy.") from self._exception
221+
158222

159223
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
160224
"""Copy over all current process environment variables to the runtime_env.

0 commit comments

Comments (0)