From cde9183b40a88f4210a7e965a430ae860aba5f6d Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Wed, 21 Aug 2024 20:18:11 -0600
Subject: [PATCH] [Bug][Frontend] Improve ZMQ client robustness (#7443)

Signed-off-by: Joe Runde
---
 tests/entrypoints/openai/rpc/__init__.py      |   0
 .../entrypoints/openai/rpc/test_zmq_client.py | 119 ++++++++++++++++++
 vllm/entrypoints/openai/api_server.py         |   5 +-
 vllm/entrypoints/openai/rpc/__init__.py       |   4 -
 vllm/entrypoints/openai/rpc/client.py         |  70 +++++++----
 vllm/envs.py                                  |   6 +
 6 files changed, 176 insertions(+), 28 deletions(-)
 create mode 100644 tests/entrypoints/openai/rpc/__init__.py
 create mode 100644 tests/entrypoints/openai/rpc/test_zmq_client.py

diff --git a/tests/entrypoints/openai/rpc/__init__.py b/tests/entrypoints/openai/rpc/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py
new file mode 100644
index 0000000000000..631d15cd03ed7
--- /dev/null
+++ b/tests/entrypoints/openai/rpc/test_zmq_client.py
@@ -0,0 +1,119 @@
+import asyncio
+import tempfile
+import unittest
+import unittest.mock
+import uuid
+
+import pytest
+import pytest_asyncio
+
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
+                                                RPCClientClosedError)
+from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
+
+
+@pytest.fixture(scope="function")
+def tmp_socket():
+    with tempfile.TemporaryDirectory() as td:
+        yield f"ipc://{td}/{uuid.uuid4()}"
+
+
+@pytest_asyncio.fixture(scope="function")
+async def dummy_server(tmp_socket, monkeypatch):
+    dummy_engine = unittest.mock.AsyncMock()
+
+    def dummy_engine_builder(*args, **kwargs):
+        return dummy_engine
+
+    with monkeypatch.context() as m:
+        m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
+        server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)
+
+    loop = asyncio.get_running_loop()
+    server_task = loop.create_task(server.run_server_loop())
+
+    try:
+        yield server
+    finally:
+        server_task.cancel()
+        server.cleanup()
+
+
+@pytest_asyncio.fixture(scope="function")
+async def client(tmp_socket):
+    client = AsyncEngineRPCClient(rpc_path=tmp_socket)
+    # Sanity check: the server is connected
+    await client._wait_for_server_rpc()
+
+    try:
+        yield client
+    finally:
+        client.close()
+
+
+@pytest.mark.asyncio
+async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
+                                                client: AsyncEngineRPCClient):
+    with monkeypatch.context() as m:
+        # Make the server _not_ reply with a model config
+        m.setattr(dummy_server, "get_config", lambda x: None)
+        m.setattr(client, "_data_timeout", 10)
+
+        # And ensure the task completes anyway
+        # (client.setup() invokes server.get_config())
+        client_task = asyncio.get_running_loop().create_task(client.setup())
+        with pytest.raises(TimeoutError, match="Server didn't reply within"):
+            await asyncio.wait_for(client_task, timeout=0.05)
+
+
+@pytest.mark.asyncio
+async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
+                                          client: AsyncEngineRPCClient):
+    with monkeypatch.context() as m:
+        # Hang all abort requests
+        m.setattr(dummy_server, "abort", lambda x: None)
+        m.setattr(client, "_data_timeout", 10)
+
+        # Ensure the client doesn't hang
+        client_task = asyncio.get_running_loop().create_task(
+            client.abort("test request id"))
+        with pytest.raises(TimeoutError, match="Server didn't reply within"):
+            await asyncio.wait_for(client_task, timeout=0.05)
+
+
+@pytest.mark.asyncio
+async def test_client_data_methods_reraise_exceptions(
+        monkeypatch, dummy_server, client: AsyncEngineRPCClient):
+    with monkeypatch.context() as m:
+        # Make the server raise some random exception
+        exception = RuntimeError("Client test exception")
+
+        def raiser():
+            raise exception
+
+        m.setattr(dummy_server.engine, "get_model_config", raiser)
+        m.setattr(client, "_data_timeout", 10)
+
+        client_task = asyncio.get_running_loop().create_task(client.setup())
+        # And ensure the task completes, raising the exception
+        with pytest.raises(RuntimeError, match=str(exception)):
+            await asyncio.wait_for(client_task, timeout=0.05)
+
+
+@pytest.mark.asyncio
+async def test_client_errors_after_closing(monkeypatch, dummy_server,
+                                           client: AsyncEngineRPCClient):
+
+    client.close()
+
+    # Healthchecks and generate requests will fail with explicit errors
+    with pytest.raises(RPCClientClosedError):
+        await client.check_health()
+    with pytest.raises(RPCClientClosedError):
+        async for _ in client.generate(None, None, None):
+            pass
+
+    # But no-ops like aborting will pass
+    await client.abort("test-request-id")
+    await client.do_log_stats()
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 8e8371ef1559a..603ac19d8c04b 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -6,7 +6,7 @@
 import re
 import tempfile
 from argparse import Namespace
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, suppress
 from http import HTTPStatus
 from typing import AsyncIterator, Optional, Set
 
@@ -83,7 +83,8 @@ async def lifespan(app: FastAPI):
     async def _force_log():
         while True:
             await asyncio.sleep(10)
-            await async_engine_client.do_log_stats()
+            with suppress(Exception):
+                await async_engine_client.do_log_stats()
 
     if not engine_args.disable_log_stats:
         task = asyncio.create_task(_force_log())
diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py
index 571dca5f61fa4..efc7e43afdcc9 100644
--- a/vllm/entrypoints/openai/rpc/__init__.py
+++ b/vllm/entrypoints/openai/rpc/__init__.py
@@ -10,10 +10,6 @@
 # Success string used for RPC instructions.
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
-# Timeouts.
-VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
-VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
-
 # Minimum value of ZMQ.SOCKET_LIMIT to run mp.
 VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
 
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index 1f26348c74d6d..55b92d41975ea 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -1,5 +1,5 @@
 import asyncio
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import Any, AsyncGenerator, Mapping, Optional
 from uuid import uuid4
 
@@ -11,13 +11,12 @@
                          ParallelConfig, SchedulerConfig)
 # yapf: disable
 from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
-                                         VLLM_RPC_HEALTH_TIMEOUT_MS,
-                                         VLLM_RPC_SERVER_START_TIMEOUT_MS,
                                          VLLM_RPC_SOCKET_LIMIT_CUTOFF,
                                          VLLM_RPC_SUCCESS_STR,
                                          VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
                                          RPCGenerateRequest, RPCUtilityRequest)
 # yapf: enable
+from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -32,6 +31,17 @@
 INPROC_PROXY_PATH = f"inproc://{uuid4()}"
 
 
+class RPCClientClosedError(Exception):
+    """Exception class raised when the client is used post-close.
+
+    The client can be closed, which closes the ZMQ context. This normally
+    happens on server shutdown. In some cases, methods like abort and
+    do_log_stats will still be called and then try to open a socket, which
+    causes a ZMQError and creates a huge stack trace.
+    So, we throw this error such that we can suppress it.
+    """
+
+
 class AsyncEngineRPCClient:
     """
     RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
 
     def __init__(self, rpc_path: str):
         self.context = zmq.asyncio.Context()
+        self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
+        self._errored = False
 
         # Maximum number of sockets that can be opened (typically 65536).
         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@@ -143,7 +155,6 @@ async def setup(self):
 
         # Wait until server is ready.
         await self._wait_for_server_rpc()
-        self._errored = False
 
         # Get the configs.
         self.model_config = await self._get_model_config_rpc()
@@ -170,6 +181,15 @@ def close(self):
     @contextmanager
     def to_proxy_socket(self):
         # Connect to the RPCServer via the proxy.
+
+        # Raise a sensible error if the client was already closed.
+        # This can happen if a server shutdown is triggered but some coroutines
+        # are still running requests.
+        # There should not be a race condition with this check because we don't
+        # yield to the event loop between here and opening the socket.
+        if self.context.closed:
+            raise RPCClientClosedError("The ZMQ client has already shut down")
+
         # Note that we use DEALER to enable asynchronous communication
         # to enable streaming.
         socket = self.context.socket(zmq.constants.DEALER)
@@ -189,9 +209,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
             # Ping RPCServer with a request.
             await socket.send_multipart([cloudpickle.dumps(request)])
 
+            # Make sure the server responds
+            if await socket.poll(timeout=self._data_timeout) == 0:
+                raise TimeoutError("Server didn't reply within "
+                                   f"{self._data_timeout} ms")
+
             # Await the data from the Server.
             data = cloudpickle.loads(await socket.recv())
 
+        if isinstance(data, Exception):
+            # Re-raise exceptions returned by the server
+            raise data
+
         if not isinstance(data, expected_type):
             # LoRAConfig can be None.
             if expected_type == LoRAConfig and data is None:
@@ -208,29 +237,28 @@ async def _send_one_way_rpc_request(
             self,
             request: RPC_REQUEST_TYPE,
             error_message: str,
-            timeout: Optional[int] = None,
             socket: Optional[zmq.asyncio.Socket] = None):
         """Send one-way RPC request to trigger an action."""
 
         async def do_rpc_call(socket: zmq.asyncio.Socket,
-                              request: RPC_REQUEST_TYPE,
-                              timeout=None):
+                              request: RPC_REQUEST_TYPE):
 
             await socket.send_multipart([cloudpickle.dumps(request)])
 
-            if timeout is not None and await socket.poll(timeout=timeout) == 0:
-                raise TimeoutError(f"Server didn't reply within {timeout} ms")
+            if await socket.poll(timeout=self._data_timeout) == 0:
+                raise TimeoutError("Server didn't reply within "
+                                   f"{self._data_timeout} ms")
 
             return cloudpickle.loads(await socket.recv())
 
         # Make a new socket connection.
         if socket is None:
             with self.to_proxy_socket() as socket:
-                response = await do_rpc_call(socket, request, timeout)
+                response = await do_rpc_call(socket, request)
 
         # Use existing socket connection.
         else:
-            response = await do_rpc_call(socket, request, timeout)
+            response = await do_rpc_call(socket, request)
 
         if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
             if isinstance(response, Exception):
@@ -255,8 +283,7 @@ async def _wait_for_server_rpc(self):
 
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.IS_SERVER_READY,
-            error_message="Unable to start RPC Server",
-            timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
+            error_message="Unable to start RPC Server")
 
     async def _get_model_config_rpc(self) -> ModelConfig:
         """Get the ModelConfig object from the RPC Server"""
@@ -308,17 +335,17 @@ async def _is_tracing_enabled_rpc(self) -> bool:
 
     async def abort(self, request_id: str):
         """Send an ABORT_REQUEST signal to the RPC Server"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCAbortRequest(request_id),
-            error_message=f"RPCAbortRequest {request_id} failed")
+        with suppress(RPCClientClosedError):
+            await self._send_one_way_rpc_request(
+                request=RPCAbortRequest(request_id),
+                error_message=f"RPCAbortRequest {request_id} failed")
 
     async def do_log_stats(self):
         """Send a DO_LOG_STATS signal to the RPC Server"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCUtilityRequest.DO_LOG_STATS,
-            error_message="RPCRequest DO_LOG_STATS failed.")
+        with suppress(RPCClientClosedError):
+            await self._send_one_way_rpc_request(
+                request=RPCUtilityRequest.DO_LOG_STATS,
+                error_message="RPCRequest DO_LOG_STATS failed.")
 
     @property
     def is_running(self) -> bool:
@@ -393,7 +420,6 @@ async def check_health(self,
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.IS_SERVER_HEALTHY,
             error_message="Got Unhealthy response from RPC Server",
-            timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
             socket=socket)
 
     async def encode(self, *args,
diff --git a/vllm/envs.py b/vllm/envs.py
index 4f7a7ad7821d5..24e09ee0e055f 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -56,6 +56,7 @@
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
+    VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
@@ -374,6 +375,11 @@ def get_default_config_root():
     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
      ("1", "true")),
 
+    # Time in ms for the zmq client to wait for a response from the backend
+    # server for simple data operations
+    "VLLM_RPC_GET_DATA_TIMEOUT_MS":
+    lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
+
     # If set, allow running the engine as a separate ray actor,
     # which is a deprecated feature soon to be removed.
     # See https://github.com/vllm-project/vllm/issues/7045
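
Reviewer note: the core pattern the client changes apply in _send_get_data_rpc_request and _send_one_way_rpc_request is "poll before recv" plus a guard on the closed ZMQ context. Below is a minimal standalone sketch of that pattern using plain pyzmq; request_with_timeout and ClientClosedError are illustrative names rather than identifiers from this patch, and pickle stands in for cloudpickle:

import pickle

import zmq
import zmq.asyncio


class ClientClosedError(Exception):
    """Illustrative stand-in for the patch's RPCClientClosedError."""


async def request_with_timeout(ctx: zmq.asyncio.Context, rpc_path: str,
                               payload: object, timeout_ms: int):
    # Guard first: socket() on a closed context raises an opaque ZMQError,
    # so surface a descriptive error instead (mirrors to_proxy_socket()).
    if ctx.closed:
        raise ClientClosedError("The ZMQ client has already shut down")

    socket = ctx.socket(zmq.DEALER)
    try:
        socket.connect(rpc_path)
        await socket.send_multipart([pickle.dumps(payload)])

        # poll() returns 0 when nothing arrives within the timeout, so an
        # unresponsive server becomes a TimeoutError rather than a hang.
        if await socket.poll(timeout=timeout_ms) == 0:
            raise TimeoutError(f"Server didn't reply within {timeout_ms} ms")

        reply = pickle.loads(await socket.recv())
        # The server side may ship an Exception back as the payload;
        # re-raising propagates server-side failures to the caller.
        if isinstance(reply, Exception):
            raise reply
        return reply
    finally:
        socket.close(linger=0)

Polling before recv() keeps a dead server from blocking the coroutine forever, and checking ctx.closed up front turns the post-shutdown failure mode into one specific, suppressible exception.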
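
One note on the new environment knob: the client binds the timeout via `from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS`, so an override has to be present in the environment before vLLM is imported. A hedged usage sketch (the 10000 ms value is an arbitrary example):

import os

# Must be set before any vllm import, since the RPC client binds the
# value when vllm.entrypoints.openai.rpc.client is first imported.
os.environ["VLLM_RPC_GET_DATA_TIMEOUT_MS"] = "10000"

from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS

assert VLLM_RPC_GET_DATA_TIMEOUT_MS == 10000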