1
1
import asyncio
2
- from contextlib import contextmanager
2
+ from contextlib import contextmanager , suppress
3
3
from typing import Any , AsyncGenerator , Mapping , Optional
4
4
from uuid import uuid4
5
5
11
11
ParallelConfig , SchedulerConfig )
12
12
# yapf: disable
13
13
from vllm .entrypoints .openai .rpc import (RPC_REQUEST_TYPE ,
14
- VLLM_RPC_HEALTH_TIMEOUT_MS ,
15
- VLLM_RPC_SERVER_START_TIMEOUT_MS ,
16
14
VLLM_RPC_SOCKET_LIMIT_CUTOFF ,
17
15
VLLM_RPC_SUCCESS_STR ,
18
16
VLLM_RPC_ZMQ_HWM , RPCAbortRequest ,
19
17
RPCGenerateRequest , RPCUtilityRequest )
20
18
# yapf: enable
19
+ from vllm .envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
21
20
from vllm .inputs import PromptInputs
22
21
from vllm .logger import init_logger
23
22
from vllm .lora .request import LoRARequest
32
31
INPROC_PROXY_PATH = f"inproc://{ uuid4 ()} "
33
32
34
33
34
class RPCClientClosedError(Exception):
    """Raised when the RPC client is used after it has been closed.

    Closing the client also closes its ZMQ context, which normally happens
    on server shutdown. Methods such as ``abort`` and ``do_log_stats`` may
    still be called afterwards; they would then try to open a socket on the
    closed context, which causes a ZMQError and a huge stack trace.

    Raising this dedicated error instead lets those callers suppress it.
    """
43
+
44
+
35
45
class AsyncEngineRPCClient :
36
46
"""
37
47
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
85
95
86
96
def __init__ (self , rpc_path : str ):
87
97
self .context = zmq .asyncio .Context ()
98
+ self ._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
99
+ self ._errored = False
88
100
89
101
# Maximum number of sockets that can be opened (typically 65536).
90
102
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@@ -143,7 +155,6 @@ async def setup(self):
143
155
144
156
# Wait until server is ready.
145
157
await self ._wait_for_server_rpc ()
146
- self ._errored = False
147
158
148
159
# Get the configs.
149
160
self .model_config = await self ._get_model_config_rpc ()
@@ -170,6 +181,15 @@ def close(self):
170
181
@contextmanager
171
182
def to_proxy_socket (self ):
172
183
# Connect to the RPCServer via the proxy.
184
+
185
+ # Raise a sensible error if the client was already closed.
186
+ # This can happen if a server shutdown is triggered but some coroutines
187
+ # are still running requests.
188
+ # There should not be a race condition with this check because we don't
189
+ # yield to the event loop between here and opening the socket.
190
+ if self .context .closed :
191
+ raise RPCClientClosedError ("The ZMQ client has already shut down" )
192
+
173
193
# Note that we use DEALER to enable asynchronous communication
174
194
# to enable streaming.
175
195
socket = self .context .socket (zmq .constants .DEALER )
@@ -189,9 +209,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
189
209
# Ping RPCServer with a request.
190
210
await socket .send_multipart ([cloudpickle .dumps (request )])
191
211
212
+ # Make sure the server responds
213
+ if await socket .poll (timeout = self ._data_timeout ) == 0 :
214
+ raise TimeoutError ("Server didn't reply within "
215
+ f"{ self ._data_timeout } ms" )
216
+
192
217
# Await the data from the Server.
193
218
data = cloudpickle .loads (await socket .recv ())
194
219
220
+ if isinstance (data , Exception ):
221
+ # Re-raise exceptions returned by the server
222
+ raise data
223
+
195
224
if not isinstance (data , expected_type ):
196
225
# LoRAConfig can be None.
197
226
if expected_type == LoRAConfig and data is None :
@@ -208,29 +237,28 @@ async def _send_one_way_rpc_request(
208
237
self ,
209
238
request : RPC_REQUEST_TYPE ,
210
239
error_message : str ,
211
- timeout : Optional [int ] = None ,
212
240
socket : Optional [zmq .asyncio .Socket ] = None ):
213
241
"""Send one-way RPC request to trigger an action."""
214
242
215
243
async def do_rpc_call (socket : zmq .asyncio .Socket ,
216
- request : RPC_REQUEST_TYPE ,
217
- timeout = None ):
244
+ request : RPC_REQUEST_TYPE ):
218
245
219
246
await socket .send_multipart ([cloudpickle .dumps (request )])
220
247
221
- if timeout is not None and await socket .poll (timeout = timeout ) == 0 :
222
- raise TimeoutError (f"Server didn't reply within { timeout } ms" )
248
+ if await socket .poll (timeout = self ._data_timeout ) == 0 :
249
+ raise TimeoutError ("Server didn't reply within "
250
+ f"{ self ._data_timeout } ms" )
223
251
224
252
return cloudpickle .loads (await socket .recv ())
225
253
226
254
# Make a new socket connection.
227
255
if socket is None :
228
256
with self .to_proxy_socket () as socket :
229
- response = await do_rpc_call (socket , request , timeout )
257
+ response = await do_rpc_call (socket , request )
230
258
231
259
# Use existing socket connection.
232
260
else :
233
- response = await do_rpc_call (socket , request , timeout )
261
+ response = await do_rpc_call (socket , request )
234
262
235
263
if not isinstance (response , str ) or response != VLLM_RPC_SUCCESS_STR :
236
264
if isinstance (response , Exception ):
@@ -255,8 +283,7 @@ async def _wait_for_server_rpc(self):
255
283
256
284
await self ._send_one_way_rpc_request (
257
285
request = RPCUtilityRequest .IS_SERVER_READY ,
258
- error_message = "Unable to start RPC Server" ,
259
- timeout = VLLM_RPC_SERVER_START_TIMEOUT_MS )
286
+ error_message = "Unable to start RPC Server" )
260
287
261
288
async def _get_model_config_rpc (self ) -> ModelConfig :
262
289
"""Get the ModelConfig object from the RPC Server"""
@@ -308,17 +335,17 @@ async def _is_tracing_enabled_rpc(self) -> bool:
308
335
309
336
async def abort(self, request_id: str):
    """Send an ABORT_REQUEST signal to the RPC Server.

    Args:
        request_id: Identifier of the request to abort; it is wrapped in
            an ``RPCAbortRequest`` and also included in the error message.

    An ``RPCClientClosedError`` raised while sending is swallowed rather
    than propagated, so an abort issued after the client was closed does
    not produce a stack trace.
    """
    # Aborts may still be issued after shutdown has closed the ZMQ context;
    # in that case the dedicated "client closed" error is ignored.
    with suppress(RPCClientClosedError):
        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")
315
342
316
343
async def do_log_stats(self):
    """Send a DO_LOG_STATS signal to the RPC Server.

    An ``RPCClientClosedError`` raised while sending is swallowed rather
    than propagated, so a stats request issued after the client was closed
    does not produce a stack trace.
    """
    # Stats logging may still be triggered after shutdown has closed the
    # ZMQ context; in that case the "client closed" error is ignored.
    with suppress(RPCClientClosedError):
        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")
322
349
323
350
@property
324
351
def is_running (self ) -> bool :
@@ -393,7 +420,6 @@ async def check_health(self,
393
420
await self ._send_one_way_rpc_request (
394
421
request = RPCUtilityRequest .IS_SERVER_HEALTHY ,
395
422
error_message = "Got Unhealthy response from RPC Server" ,
396
- timeout = VLLM_RPC_HEALTH_TIMEOUT_MS ,
397
423
socket = socket )
398
424
399
425
async def encode (self , * args ,
0 commit comments