Skip to content

Commit

Permalink
Add health check, make async Engine more robust (vllm-project#3015)
Browse files Browse the repository at this point in the history
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
  • Loading branch information
Yard1 and zhuohan123 authored Mar 4, 2024
1 parent 22de452 commit ff578ca
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 65 deletions.
32 changes: 16 additions & 16 deletions tests/async_engine/test_async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,8 @@ async def step_async(self):
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []

async def encode_request_async(
self,
*args,
**kwargs,
):
return [1]
async def encode_request_async(self, *args, **kwargs):
    # Mock override: the tests never inspect the encoded prompt, so
    # accept anything and produce nothing.
    pass

def generate(self, request_id):
self.request_id = request_id
Expand All @@ -43,13 +39,16 @@ def add_request(self, **kwargs):
self.add_request_calls += 1

async def add_request_async(self, **kwargs):
    """Mock override: record that a request was added; args are ignored."""
    del kwargs
    self.add_request_calls += 1

def abort_request(self, request_id):
    """Mock override: count abort calls; the id itself is irrelevant."""
    del request_id
    self.abort_request_calls += 1

def has_unfinished_requests(self):
    """The mock engine is 'busy' exactly while a request id is registered."""
    busy = self.request_id is not None
    return busy


class MockAsyncLLMEngine(AsyncLLMEngine):

Expand All @@ -72,20 +71,21 @@ async def test_new_requests_event():
await engine.add_request("2", "", None)
engine.engine.generate("2")
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls == 2
await asyncio.sleep(0)
assert engine.engine.step_calls == 3
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls >= 2
await asyncio.sleep(0.001)
assert engine.engine.step_calls >= 3
engine.engine.stop_generating()
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await asyncio.sleep(0.001)
old_step_calls = engine.engine.step_calls
await asyncio.sleep(0.001)
assert engine.engine.step_calls == old_step_calls

await engine.add_request("3", "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
assert engine.engine.step_calls == old_step_calls + 1
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
assert engine.engine.step_calls == old_step_calls + 1
38 changes: 15 additions & 23 deletions tests/async_engine/test_request_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,25 @@
from vllm.outputs import RequestOutput


class DummyEvent:
    """Minimal stand-in for an event object that only records its state.

    Exposes a boolean ``flag`` plus the ``set()``/``clear()`` transitions;
    no waiting or notification behavior is provided.
    """

    def __init__(self):
        # Starts unset, mirroring a freshly created event.
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False

def test_request_tracker():
@pytest.mark.asyncio
async def test_request_tracker():
tracker = RequestTracker()
tracker.new_requests_event = DummyEvent()
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event.flag
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert not tracker.new_requests_event.is_set()
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
assert not stream_1.finished

stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event.flag
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert not tracker.new_requests_event.is_set()
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
Expand All @@ -43,7 +33,7 @@ def test_request_tracker():
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request("1")
assert not tracker.new_requests_event.flag
assert not tracker.new_requests_event.is_set()

tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
Expand All @@ -54,19 +44,21 @@ def test_request_tracker():

stream_4 = tracker.add_request("4")
tracker.abort_request("4")
assert tracker.new_requests_event.flag
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
assert not new
assert stream_4.finished

stream_5 = tracker.add_request("5")
assert tracker.new_requests_event.flag
assert tracker.new_requests_event.is_set()
tracker.process_request_output(
RequestOutput("2", "output", [], [], [], bool(finished)))
RequestOutput("2", "output", [], [], [], finished=True))
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert not tracker.new_requests_event.is_set()
assert len(finished) == 1
assert "2" in finished
assert len(new) == 1
Expand Down
113 changes: 87 additions & 26 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import os
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator)
Union, AsyncIterator, Callable)

from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
Expand All @@ -14,28 +15,31 @@
from vllm.sampling_params import SamplingParams

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = int(
os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))


class AsyncEngineDeadError(RuntimeError):
pass


def _raise_exception_on_finish(task: asyncio.Task,
request_tracker: "RequestTracker") -> None:
def _raise_exception_on_finish(
task: asyncio.Task, error_callback: Callable[[Exception],
None]) -> None:
msg = ("Task finished unexpectedly. This should never happen! "
"Please open an issue on Github.")

exception = None
try:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from exc
task.result()
# NOTE: This will be thrown if task exits normally (which it should not)
raise AsyncEngineDeadError(msg)
except Exception as exc:
request_tracker.propagate_exception(exc)
raise exc
except Exception as e:
exception = e
logger.error("Engine background task failed", exc_info=e)
error_callback(exception)
raise AsyncEngineDeadError(
msg + " See stack trace above for the actual cause.") from e


class AsyncStream:
Expand Down Expand Up @@ -78,13 +82,13 @@ def __init__(self) -> None:
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = None
self.new_requests_event = asyncio.Event()

def __contains__(self, item):
return item in self._request_streams

def init_event(self):
self.new_requests_event = asyncio.Event()
def __len__(self) -> int:
return len(self._request_streams)

def propagate_exception(self,
exc: Exception,
Expand All @@ -93,9 +97,11 @@ def propagate_exception(self,
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
self.abort_request(request_id)
else:
for stream in self._request_streams.values():
for rid, stream in self._request_streams.items():
stream.put(exc)
self.abort_request(rid)

def process_request_output(self,
request_output: RequestOutput,
Expand Down Expand Up @@ -172,12 +178,15 @@ def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)

self.new_requests_event.clear()

return new_requests, finished_requests

async def wait_for_new_requests(self):
await self.new_requests_event.wait()
if not self.has_new_requests():
await self.new_requests_event.wait()
self.new_requests_event.clear()

def has_new_requests(self):
    """Whether any not-yet-consumed requests are waiting in the queue."""
    return not self._new_requests.empty()


class _AsyncLLMEngine(LLMEngine):
Expand Down Expand Up @@ -285,6 +294,10 @@ async def _run_workers_async(
all_outputs = await asyncio.gather(*coros)
return all_outputs

async def check_health_async(self):
    """Raise if the engine is unhealthy (e.g. a Ray worker actor died)."""
    self._check_if_any_actor_is_dead()


class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
Expand Down Expand Up @@ -335,27 +348,48 @@ def __init__(self,
# collected
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None

@property
def is_running(self) -> bool:
return (self.background_loop is not None
and not self.background_loop.done())
and not self._background_loop_unshielded.done())

@property
def is_stopped(self) -> bool:
    """True once the engine errored or its background loop has finished."""
    if self.errored:
        return True
    return (self.background_loop is not None
            and self._background_loop_unshielded.done())

@property
def errored(self) -> bool:
    """Whether a fatal exception has been recorded for this engine."""
    return self._errored_with is not None

def set_errored(self, exc: Exception) -> None:
    """Record the fatal exception so the ``errored`` property reports it."""
    self._errored_with = exc

def _error_callback(self, exc: Exception) -> None:
    """Mark the engine dead, then propagate the failure to request streams."""
    self.set_errored(exc)
    self._request_tracker.propagate_exception(exc)

def get_tokenizer(self):
    """Expose the wrapped engine's underlying tokenizer object."""
    return self.engine.tokenizer.tokenizer

def start_background_loop(self) -> None:
"""Start the background loop."""
if self.errored:
raise AsyncEngineDeadError(
"Background loop has errored already.") from self._errored_with
if self.is_running:
raise RuntimeError("Background loop is already running.")
self._request_tracker.init_event()
# Initialize the RequestTracker here so it uses the right event loop.
self._request_tracker = RequestTracker()

self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish,
request_tracker=self._request_tracker))
error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded)

def _init_engine(self, *args,
Expand Down Expand Up @@ -423,12 +457,23 @@ async def _engine_abort(self, request_ids: Iterable[str]):
self.engine.abort_request(request_ids)

async def run_engine_loop(self):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress = False
while True:
if not has_requests_in_progress:
logger.debug("Waiting for new requests...")
await self._request_tracker.wait_for_new_requests()
has_requests_in_progress = await self.engine_step()
logger.debug("Got new requests!")

# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
has_requests_in_progress = await asyncio.wait_for(
self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
self.set_errored(exc)
raise
await asyncio.sleep(0)

async def add_request(
Expand Down Expand Up @@ -647,3 +692,19 @@ async def do_log_stats(self) -> None:
await self.engine.do_log_stats.remote()
else:
self.engine.do_log_stats()

async def check_health(self) -> None:
    """Raises an error if engine is unhealthy.

    Raises:
        AsyncEngineDeadError: if the background loop has already stopped.
        RuntimeError: if the Ray-hosted engine actor is dead.
    """
    t = time.perf_counter()
    logger.debug("Starting health check...")
    if self.is_stopped:
        raise AsyncEngineDeadError("Background loop is stopped.")

    if self.engine_use_ray:
        try:
            await self.engine.check_health.remote()
        except ray.exceptions.RayActorError as e:
            raise RuntimeError("Engine is dead.") from e
    else:
        await self.engine.check_health_async()
    # Lazy %-formatting: only rendered when DEBUG logging is enabled.
    logger.debug("Health check took %ss", time.perf_counter() - t)
20 changes: 20 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,3 +1119,23 @@ def _compiled_ray_dag(self):
for worker in self.workers
])
return forward_dag.experimental_compile()

def check_health(self) -> None:
    """Raise an error if the engine is unhealthy (a worker actor died)."""
    self._check_if_any_actor_is_dead()

def _check_if_any_actor_is_dead(self):
    """Raise RuntimeError if any of the Ray worker actors is DEAD."""
    if not self.parallel_config.worker_use_ray:
        # Without Ray there are no remote actors to probe.
        return
    if not self.workers:
        return

    dead_actors = []
    for worker in self.workers:
        state = ray.state.actors(worker._ray_actor_id.hex())  # pylint: disable=protected-access
        if state["State"] == "DEAD":
            dead_actors.append(worker)
    if dead_actors:
        raise RuntimeError("At least one Worker is dead. "
                           f"Dead Workers: {dead_actors}. ")

0 comments on commit ff578ca

Please sign in to comment.