Skip to content

Commit cde9183

Browse files
authored
[Bug][Frontend] Improve ZMQ client robustness (vllm-project#7443)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
1 parent df1a211 commit cde9183

File tree

6 files changed

+176
-28
lines changed

6 files changed

+176
-28
lines changed

tests/entrypoints/openai/rpc/__init__.py

Whitespace-only changes.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import asyncio
2+
import tempfile
3+
import unittest
4+
import unittest.mock
5+
import uuid
6+
7+
import pytest
8+
import pytest_asyncio
9+
10+
from vllm.engine.async_llm_engine import AsyncLLMEngine
11+
from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
12+
RPCClientClosedError)
13+
from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
14+
15+
16+
@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a unique ipc:// socket path inside a throwaway directory."""
    with tempfile.TemporaryDirectory() as socket_dir:
        yield f"ipc://{socket_dir}/{uuid.uuid4()}"
20+
21+
22+
@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
    """Run an AsyncEngineRPCServer backed by a mocked engine.

    The engine builder is patched so no real AsyncLLMEngine is created;
    the mock engine is exposed via ``server.engine`` for per-test patching.
    The server loop runs as a background task and is torn down afterwards.
    """
    dummy_engine = unittest.mock.AsyncMock()

    def dummy_engine_builder(*args, **kwargs):
        return dummy_engine

    with monkeypatch.context() as m:
        m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
        server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    try:
        yield server
    finally:
        server_task.cancel()
        # Await the cancelled task so cancellation actually propagates
        # before cleanup; otherwise asyncio may warn that the task was
        # destroyed while still pending.
        try:
            await server_task
        except asyncio.CancelledError:
            pass
        server.cleanup()
41+
42+
43+
@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
    """Yield an AsyncEngineRPCClient connected to the test socket."""
    rpc_client = AsyncEngineRPCClient(rpc_path=tmp_socket)
    # Sanity check: the server must already be reachable.
    await rpc_client._wait_for_server_rpc()

    try:
        yield rpc_client
    finally:
        rpc_client.close()
53+
54+
55+
@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
                                                client: AsyncEngineRPCClient):
    """A data request must time out instead of hanging on a silent server."""
    with monkeypatch.context() as patcher:
        # Server never answers the model-config request.
        patcher.setattr(dummy_server, "get_config", lambda x: None)
        # Shrink the client-side timeout so the test completes quickly.
        patcher.setattr(client, "_data_timeout", 10)

        # client.setup() invokes server.get_config(); it must still finish,
        # raising the timeout error rather than blocking forever.
        setup_task = asyncio.get_running_loop().create_task(client.setup())
        with pytest.raises(TimeoutError, match="Server didn't reply within"):
            await asyncio.wait_for(setup_task, timeout=0.05)
68+
69+
70+
@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
                                          client: AsyncEngineRPCClient):
    """Abort requests must also obey the data timeout and never hang."""
    with monkeypatch.context() as patcher:
        # Server swallows all abort requests without replying.
        patcher.setattr(dummy_server, "abort", lambda x: None)
        patcher.setattr(client, "_data_timeout", 10)

        # The abort call must surface a timeout instead of blocking.
        abort_task = asyncio.get_running_loop().create_task(
            client.abort("test request id"))
        with pytest.raises(TimeoutError, match="Server didn't reply within"):
            await asyncio.wait_for(abort_task, timeout=0.05)
83+
84+
85+
@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
        monkeypatch, dummy_server, client: AsyncEngineRPCClient):
    """Exceptions raised server-side must be re-raised on the client."""
    with monkeypatch.context() as patcher:
        error = RuntimeError("Client test exception")

        def raise_error():
            raise error

        # Make the engine's config lookup blow up on the server.
        patcher.setattr(dummy_server.engine, "get_model_config", raise_error)
        patcher.setattr(client, "_data_timeout", 10)

        setup_task = asyncio.get_running_loop().create_task(client.setup())
        # The task completes by surfacing the server's exception.
        with pytest.raises(RuntimeError, match=str(error)):
            await asyncio.wait_for(setup_task, timeout=0.05)
102+
103+
104+
@pytest.mark.asyncio
async def test_client_errors_after_closing(monkeypatch, dummy_server,
                                           client: AsyncEngineRPCClient):
    """After close(), data methods fail loudly while no-ops stay silent."""
    client.close()

    # Health checks and generation raise an explicit closed-client error
    # instead of a raw ZMQError with a huge stack trace.
    with pytest.raises(RPCClientClosedError):
        await client.check_health()
    with pytest.raises(RPCClientClosedError):
        async for _ in client.generate(None, None, None):
            pass

    # Aborts and stats logging are best-effort no-ops once closed.
    await client.abort("test-request-id")
    await client.do_log_stats()

vllm/entrypoints/openai/api_server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import tempfile
88
from argparse import Namespace
9-
from contextlib import asynccontextmanager
9+
from contextlib import asynccontextmanager, suppress
1010
from http import HTTPStatus
1111
from typing import AsyncIterator, Optional, Set
1212

@@ -83,7 +83,8 @@ async def lifespan(app: FastAPI):
8383
async def _force_log():
8484
while True:
8585
await asyncio.sleep(10)
86-
await async_engine_client.do_log_stats()
86+
with suppress(Exception):
87+
await async_engine_client.do_log_stats()
8788

8889
if not engine_args.disable_log_stats:
8990
task = asyncio.create_task(_force_log())

vllm/entrypoints/openai/rpc/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010
# Success string used for RPC instructions.
1111
VLLM_RPC_SUCCESS_STR = "SUCCESS"
1212

13-
# Timeouts.
14-
VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
15-
VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
16-
1713
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
1814
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
1915

vllm/entrypoints/openai/rpc/client.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from contextlib import contextmanager
2+
from contextlib import contextmanager, suppress
33
from typing import Any, AsyncGenerator, Mapping, Optional
44
from uuid import uuid4
55

@@ -11,13 +11,12 @@
1111
ParallelConfig, SchedulerConfig)
1212
# yapf: disable
1313
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
14-
VLLM_RPC_HEALTH_TIMEOUT_MS,
15-
VLLM_RPC_SERVER_START_TIMEOUT_MS,
1614
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
1715
VLLM_RPC_SUCCESS_STR,
1816
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
1917
RPCGenerateRequest, RPCUtilityRequest)
2018
# yapf: enable
19+
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
2120
from vllm.inputs import PromptInputs
2221
from vllm.logger import init_logger
2322
from vllm.lora.request import LoRARequest
@@ -32,6 +31,17 @@
3231
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
3332

3433

34+
class RPCClientClosedError(Exception):
35+
"""Exception class raised when the client is used post-close.
36+
37+
The client can be closed, which closes the ZMQ context. This normally
38+
happens on server shutdown. In some cases, methods like abort and
39+
do_log_stats will still be called and then try to open a socket, which
40+
causes a ZMQError and creates a huge stack trace.
41+
So, we throw this error such that we can suppress it.
42+
"""
43+
44+
3545
class AsyncEngineRPCClient:
3646
"""
3747
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
8595

8696
def __init__(self, rpc_path: str):
8797
self.context = zmq.asyncio.Context()
98+
self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
99+
self._errored = False
88100

89101
# Maximum number of sockets that can be opened (typically 65536).
90102
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@@ -143,7 +155,6 @@ async def setup(self):
143155

144156
# Wait until server is ready.
145157
await self._wait_for_server_rpc()
146-
self._errored = False
147158

148159
# Get the configs.
149160
self.model_config = await self._get_model_config_rpc()
@@ -170,6 +181,15 @@ def close(self):
170181
@contextmanager
171182
def to_proxy_socket(self):
172183
# Connect to the RPCServer via the proxy.
184+
185+
# Raise a sensible error if the client was already closed.
186+
# This can happen if a server shutdown is triggered but some coroutines
187+
# are still running requests.
188+
# There should not be a race condition with this check because we don't
189+
# yield to the event loop between here and opening the socket.
190+
if self.context.closed:
191+
raise RPCClientClosedError("The ZMQ client has already shut down")
192+
173193
# Note that we use DEALER to enable asynchronous communication
174194
# to enable streaming.
175195
socket = self.context.socket(zmq.constants.DEALER)
@@ -189,9 +209,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
189209
# Ping RPCServer with a request.
190210
await socket.send_multipart([cloudpickle.dumps(request)])
191211

212+
# Make sure the server responds
213+
if await socket.poll(timeout=self._data_timeout) == 0:
214+
raise TimeoutError("Server didn't reply within "
215+
f"{self._data_timeout} ms")
216+
192217
# Await the data from the Server.
193218
data = cloudpickle.loads(await socket.recv())
194219

220+
if isinstance(data, Exception):
221+
# Re-raise exceptions returned by the server
222+
raise data
223+
195224
if not isinstance(data, expected_type):
196225
# LoRAConfig can be None.
197226
if expected_type == LoRAConfig and data is None:
@@ -208,29 +237,28 @@ async def _send_one_way_rpc_request(
208237
self,
209238
request: RPC_REQUEST_TYPE,
210239
error_message: str,
211-
timeout: Optional[int] = None,
212240
socket: Optional[zmq.asyncio.Socket] = None):
213241
"""Send one-way RPC request to trigger an action."""
214242

215243
async def do_rpc_call(socket: zmq.asyncio.Socket,
216-
request: RPC_REQUEST_TYPE,
217-
timeout=None):
244+
request: RPC_REQUEST_TYPE):
218245

219246
await socket.send_multipart([cloudpickle.dumps(request)])
220247

221-
if timeout is not None and await socket.poll(timeout=timeout) == 0:
222-
raise TimeoutError(f"Server didn't reply within {timeout} ms")
248+
if await socket.poll(timeout=self._data_timeout) == 0:
249+
raise TimeoutError("Server didn't reply within "
250+
f"{self._data_timeout} ms")
223251

224252
return cloudpickle.loads(await socket.recv())
225253

226254
# Make a new socket connection.
227255
if socket is None:
228256
with self.to_proxy_socket() as socket:
229-
response = await do_rpc_call(socket, request, timeout)
257+
response = await do_rpc_call(socket, request)
230258

231259
# Use existing socket connection.
232260
else:
233-
response = await do_rpc_call(socket, request, timeout)
261+
response = await do_rpc_call(socket, request)
234262

235263
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
236264
if isinstance(response, Exception):
@@ -255,8 +283,7 @@ async def _wait_for_server_rpc(self):
255283

256284
await self._send_one_way_rpc_request(
257285
request=RPCUtilityRequest.IS_SERVER_READY,
258-
error_message="Unable to start RPC Server",
259-
timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
286+
error_message="Unable to start RPC Server")
260287

261288
async def _get_model_config_rpc(self) -> ModelConfig:
262289
"""Get the ModelConfig object from the RPC Server"""
@@ -308,17 +335,17 @@ async def _is_tracing_enabled_rpc(self) -> bool:
308335

309336
async def abort(self, request_id: str):
310337
"""Send an ABORT_REQUEST signal to the RPC Server"""
311-
312-
await self._send_one_way_rpc_request(
313-
request=RPCAbortRequest(request_id),
314-
error_message=f"RPCAbortRequest {request_id} failed")
338+
with suppress(RPCClientClosedError):
339+
await self._send_one_way_rpc_request(
340+
request=RPCAbortRequest(request_id),
341+
error_message=f"RPCAbortRequest {request_id} failed")
315342

316343
async def do_log_stats(self):
317344
"""Send a DO_LOG_STATS signal to the RPC Server"""
318-
319-
await self._send_one_way_rpc_request(
320-
request=RPCUtilityRequest.DO_LOG_STATS,
321-
error_message="RPCRequest DO_LOG_STATS failed.")
345+
with suppress(RPCClientClosedError):
346+
await self._send_one_way_rpc_request(
347+
request=RPCUtilityRequest.DO_LOG_STATS,
348+
error_message="RPCRequest DO_LOG_STATS failed.")
322349

323350
@property
324351
def is_running(self) -> bool:
@@ -393,7 +420,6 @@ async def check_health(self,
393420
await self._send_one_way_rpc_request(
394421
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
395422
error_message="Got Unhealthy response from RPC Server",
396-
timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
397423
socket=socket)
398424

399425
async def encode(self, *args,

vllm/envs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
VERBOSE: bool = False
5757
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
5858
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
59+
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
5960
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
6061
VLLM_PLUGINS: Optional[List[str]] = None
6162
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
@@ -374,6 +375,11 @@ def get_default_config_root():
374375
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
375376
("1", "true")),
376377

378+
# Time in ms for the zmq client to wait for a response from the backend
379+
# server for simple data operations
380+
"VLLM_RPC_GET_DATA_TIMEOUT_MS":
381+
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
382+
377383
# If set, allow running the engine as a separate ray actor,
378384
# which is a deprecated feature soon to be removed.
379385
# See https://github.com/vllm-project/vllm/issues/7045

0 commit comments

Comments
 (0)