diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 1edb19c550010..1e31ff7373031 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,12 +25,8 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] - async def encode_request_async( - self, - *args, - **kwargs, - ): - return [1] + async def encode_request_async(self, *args, **kwargs): + pass def generate(self, request_id): self.request_id = request_id @@ -43,13 +39,16 @@ def add_request(self, **kwargs): self.add_request_calls += 1 async def add_request_async(self, **kwargs): - del kwargs # Unused self.add_request_calls += 1 + return def abort_request(self, request_id): del request_id # Unused self.abort_request_calls += 1 + def has_unfinished_requests(self): + return self.request_id is not None + class MockAsyncLLMEngine(AsyncLLMEngine): @@ -72,20 +71,21 @@ async def test_new_requests_event(): await engine.add_request("2", "", None) engine.engine.generate("2") await asyncio.sleep(0) - assert engine.engine.add_request_calls == 2 - assert engine.engine.step_calls == 2 await asyncio.sleep(0) - assert engine.engine.step_calls == 3 + assert engine.engine.add_request_calls == 2 + assert engine.engine.step_calls >= 2 + await asyncio.sleep(0.001) + assert engine.engine.step_calls >= 3 engine.engine.stop_generating() - await asyncio.sleep(0) - assert engine.engine.step_calls == 4 - await asyncio.sleep(0) - assert engine.engine.step_calls == 4 + await asyncio.sleep(0.001) + old_step_calls = engine.engine.step_calls + await asyncio.sleep(0.001) + assert engine.engine.step_calls == old_step_calls await engine.add_request("3", "", None) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 - assert engine.engine.step_calls == 5 + assert engine.engine.step_calls == old_step_calls + 1 await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 - assert engine.engine.step_calls == 5 + assert engine.engine.step_calls == old_step_calls + 1 diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py index 4043558bae919..7b1f4a9e1eb2f 100644 --- a/tests/async_engine/test_request_tracker.py +++ b/tests/async_engine/test_request_tracker.py @@ -4,25 +4,14 @@ from vllm.outputs import RequestOutput -class DummyEvent: - - def __init__(self): - self.flag = False - - def set(self): - self.flag = True - - def clear(self): - self.flag = False - - -def test_request_tracker(): +@pytest.mark.asyncio +async def test_request_tracker(): tracker = RequestTracker() - tracker.new_requests_event = DummyEvent() stream_1 = tracker.add_request("1") - assert tracker.new_requests_event.flag + assert tracker.new_requests_event.is_set() + await tracker.wait_for_new_requests() new, finished = tracker.get_new_and_finished_requests() - assert not tracker.new_requests_event.flag + assert not tracker.new_requests_event.is_set() assert len(new) == 1 assert new[0]["request_id"] == "1" assert not finished @@ -30,9 +19,10 @@ def test_request_tracker(): stream_2 = tracker.add_request("2") stream_3 = tracker.add_request("3") - assert tracker.new_requests_event.flag + assert tracker.new_requests_event.is_set() + await tracker.wait_for_new_requests() new, finished = tracker.get_new_and_finished_requests() - assert not tracker.new_requests_event.flag + assert not tracker.new_requests_event.is_set() assert len(new) == 2 assert new[0]["request_id"] == "2" assert new[1]["request_id"] == "3" @@ -43,7 +33,7 @@ def test_request_tracker(): # request_ids must be unique with pytest.raises(KeyError): tracker.add_request("1") - assert not tracker.new_requests_event.flag + assert not tracker.new_requests_event.is_set() tracker.abort_request("1") new, finished = tracker.get_new_and_finished_requests() @@ -54,7 +44,8 @@ def test_request_tracker(): stream_4 = tracker.add_request("4") tracker.abort_request("4") - assert tracker.new_requests_event.flag + assert tracker.new_requests_event.is_set() + await tracker.wait_for_new_requests() new, finished = tracker.get_new_and_finished_requests() assert len(finished) == 1 assert "4" in finished @@ -62,11 +53,12 @@ def test_request_tracker(): assert stream_4.finished stream_5 = tracker.add_request("5") - assert tracker.new_requests_event.flag + assert tracker.new_requests_event.is_set() tracker.process_request_output( - RequestOutput("2", "output", [], [], [], bool(finished))) + RequestOutput("2", "output", [], [], [], finished=True)) + await tracker.wait_for_new_requests() new, finished = tracker.get_new_and_finished_requests() - assert not tracker.new_requests_event.flag + assert not tracker.new_requests_event.is_set() assert len(finished) == 1 assert "2" in finished assert len(new) == 1 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index df66139fddcd1..65ab0c0634176 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,9 @@ import asyncio +import os import time from functools import partial from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, - Union, AsyncIterator) + Union, AsyncIterator, Callable) from vllm.lora.request import LoRARequest from vllm.config import ModelConfig @@ -14,28 +15,31 @@ from vllm.sampling_params import SamplingParams logger = init_logger(__name__) +ENGINE_ITERATION_TIMEOUT_S = int( + os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")) class AsyncEngineDeadError(RuntimeError): pass -def _raise_exception_on_finish(task: asyncio.Task, - request_tracker: "RequestTracker") -> None: +def _raise_exception_on_finish( + task: asyncio.Task, error_callback: Callable[[Exception], + None]) -> None: msg = ("Task finished unexpectedly. This should never happen! " "Please open an issue on Github.") + + exception = None try: - try: - task.result() - except asyncio.CancelledError: - return - except Exception as exc: - raise AsyncEngineDeadError( - msg + " See stack trace above for the actual cause.") from exc + task.result() + # NOTE: This will be thrown if task exits normally (which it should not) raise AsyncEngineDeadError(msg) - except Exception as exc: - request_tracker.propagate_exception(exc) - raise exc + except Exception as e: + exception = e + logger.error("Engine background task failed", exc_info=e) + error_callback(exception) + raise AsyncEngineDeadError( + msg + " See stack trace above for the actual cause.") from e class AsyncStream: @@ -78,13 +82,13 @@ def __init__(self) -> None: self._finished_requests: asyncio.Queue[str] = asyncio.Queue() self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() - self.new_requests_event = None + self.new_requests_event = asyncio.Event() def __contains__(self, item): return item in self._request_streams - def init_event(self): - self.new_requests_event = asyncio.Event() + def __len__(self) -> int: + return len(self._request_streams) def propagate_exception(self, exc: Exception, @@ -93,9 +97,11 @@ def propagate_exception(self, (all if request_id is None).""" if request_id is not None: self._request_streams[request_id].put(exc) + self.abort_request(request_id) else: - for stream in self._request_streams.values(): + for rid, stream in self._request_streams.items(): stream.put(exc) + self.abort_request(rid) def process_request_output(self, request_output: RequestOutput, @@ -172,12 +178,15 @@ def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: self._request_streams[stream.request_id] = stream new_requests.append(new_request) - self.new_requests_event.clear() - return new_requests, finished_requests async def wait_for_new_requests(self): - await self.new_requests_event.wait() + if not self.has_new_requests(): + await self.new_requests_event.wait() + self.new_requests_event.clear() + + def has_new_requests(self): + return not self._new_requests.empty() class _AsyncLLMEngine(LLMEngine): @@ -285,6 +294,10 @@ async def _run_workers_async( all_outputs = await asyncio.gather(*coros) return all_outputs + async def check_health_async(self): + """Raises an error if engine is unhealthy.""" + self._check_if_any_actor_is_dead() + class AsyncLLMEngine: """An asynchronous wrapper for LLMEngine. @@ -335,27 +348,48 @@ def __init__(self, # collected self._background_loop_unshielded = None self.start_engine_loop = start_engine_loop - self._request_tracker = RequestTracker() + self._request_tracker: Optional[RequestTracker] = None + self._errored_with: Optional[BaseException] = None @property def is_running(self) -> bool: return (self.background_loop is not None - and not self.background_loop.done()) + and not self._background_loop_unshielded.done()) + + @property + def is_stopped(self) -> bool: + return self.errored or (self.background_loop is not None + and self._background_loop_unshielded.done()) + + @property + def errored(self) -> bool: + return self._errored_with is not None + + def set_errored(self, exc: Exception) -> None: + self._errored_with = exc + + def _error_callback(self, exc: Exception) -> None: + self.set_errored(exc) + self._request_tracker.propagate_exception(exc) def get_tokenizer(self): return self.engine.tokenizer.tokenizer def start_background_loop(self) -> None: """Start the background loop.""" + if self.errored: + raise AsyncEngineDeadError( + "Background loop has errored already.") from self._errored_with if self.is_running: raise RuntimeError("Background loop is already running.") - self._request_tracker.init_event() + # Initialize the RequestTracker here so it uses the right event loop. + self._request_tracker = RequestTracker() self._background_loop_unshielded = asyncio.get_event_loop( ).create_task(self.run_engine_loop()) self._background_loop_unshielded.add_done_callback( partial(_raise_exception_on_finish, - request_tracker=self._request_tracker)) + error_callback=self._error_callback)) self.background_loop = asyncio.shield(self._background_loop_unshielded) def _init_engine(self, *args, @@ -423,12 +457,23 @@ async def _engine_abort(self, request_ids: Iterable[str]): self.engine.abort_request(request_ids) async def run_engine_loop(self): - # Initialize the RequestTracker here so it uses the right event loop. has_requests_in_progress = False while True: if not has_requests_in_progress: + logger.debug("Waiting for new requests...") await self._request_tracker.wait_for_new_requests() - has_requests_in_progress = await self.engine_step() + logger.debug("Got new requests!") + + # Abort if iteration takes too long due to unrecoverable errors + # (eg. NCCL timeouts). + try: + has_requests_in_progress = await asyncio.wait_for( + self.engine_step(), ENGINE_ITERATION_TIMEOUT_S) + except asyncio.TimeoutError as exc: + logger.error( + "Engine iteration timed out. This should never happen!") + self.set_errored(exc) + raise await asyncio.sleep(0) async def add_request( @@ -647,3 +692,19 @@ async def do_log_stats(self) -> None: await self.engine.do_log_stats.remote() else: self.engine.do_log_stats() + + async def check_health(self): + """Raises an error if engine is unhealthy.""" + t = time.perf_counter() + logger.debug("Starting health check...") + if self.is_stopped: + raise AsyncEngineDeadError("Background loop is stopped.") + + if self.engine_use_ray: + try: + await self.engine.check_health.remote() + except ray.exceptions.RayActorError as e: + raise RuntimeError("Engine is dead.") from e + else: + await self.engine.check_health_async() + logger.debug(f"Health check took {time.perf_counter()-t}s") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 703756996b7f7..1f518cbf39b21 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1119,3 +1119,23 @@ def _compiled_ray_dag(self): for worker in self.workers ]) return forward_dag.experimental_compile() + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + self._check_if_any_actor_is_dead() + + def _check_if_any_actor_is_dead(self): + if not self.parallel_config.worker_use_ray: + return + + if not self.workers: + return + + dead_actors = [] + for actor in self.workers: + actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access + if actor_state["State"] == "DEAD": + dead_actors.append(actor) + if dead_actors: + raise RuntimeError("At least one Worker is dead. " + f"Dead Workers: {dead_actors}. ")