[Bug][Frontend] Improve ZMQ client robustness (vllm-project#7443)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
joerunde authored Aug 22, 2024
1 parent df1a211 commit cde9183
Showing 6 changed files with 176 additions and 28 deletions.
Empty file.
119 changes: 119 additions & 0 deletions tests/entrypoints/openai/rpc/test_zmq_client.py
@@ -0,0 +1,119 @@
import asyncio
import tempfile
import unittest
import unittest.mock
import uuid

import pytest
import pytest_asyncio

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
                                                RPCClientClosedError)
from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer


@pytest.fixture(scope="function")
def tmp_socket():
    with tempfile.TemporaryDirectory() as td:
        yield f"ipc://{td}/{uuid.uuid4()}"


@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
    dummy_engine = unittest.mock.AsyncMock()

    def dummy_engine_builder(*args, **kwargs):
        return dummy_engine

    with monkeypatch.context() as m:
        m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
        server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    try:
        yield server
    finally:
        server_task.cancel()
        server.cleanup()


@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
    client = AsyncEngineRPCClient(rpc_path=tmp_socket)
    # Sanity check: the server is connected
    await client._wait_for_server_rpc()

    try:
        yield client
    finally:
        client.close()


@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
                                                client: AsyncEngineRPCClient):
    with monkeypatch.context() as m:
        # Make the server _not_ reply with a model config
        m.setattr(dummy_server, "get_config", lambda x: None)
        m.setattr(client, "_data_timeout", 10)

        # And ensure the task completes anyway
        # (client.setup() invokes server.get_config())
        client_task = asyncio.get_running_loop().create_task(client.setup())
        with pytest.raises(TimeoutError, match="Server didn't reply within"):
            await asyncio.wait_for(client_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
                                          client: AsyncEngineRPCClient):
    with monkeypatch.context() as m:
        # Hang all abort requests
        m.setattr(dummy_server, "abort", lambda x: None)
        m.setattr(client, "_data_timeout", 10)

        # Ensure the client doesn't hang
        client_task = asyncio.get_running_loop().create_task(
            client.abort("test request id"))
        with pytest.raises(TimeoutError, match="Server didn't reply within"):
            await asyncio.wait_for(client_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
        monkeypatch, dummy_server, client: AsyncEngineRPCClient):
    with monkeypatch.context() as m:
        # Make the server raise some random exception
        exception = RuntimeError("Client test exception")

        def raiser():
            raise exception

        m.setattr(dummy_server.engine, "get_model_config", raiser)
        m.setattr(client, "_data_timeout", 10)

        client_task = asyncio.get_running_loop().create_task(client.setup())
        # And ensure the task completes, raising the exception
        with pytest.raises(RuntimeError, match=str(exception)):
            await asyncio.wait_for(client_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_errors_after_closing(monkeypatch, dummy_server,
                                           client: AsyncEngineRPCClient):

    client.close()

    # Healthchecks and generate requests will fail with explicit errors
    with pytest.raises(RPCClientClosedError):
        await client.check_health()
    with pytest.raises(RPCClientClosedError):
        async for _ in client.generate(None, None, None):
            pass

    # But no-ops like aborting will pass
    await client.abort("test-request-id")
    await client.do_log_stats()
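
Note: the robustness mechanism these tests exercise is a poll-before-recv deadline on the ZMQ socket. A minimal standalone sketch of that pattern (illustrative only, not vLLM code, assuming pyzmq's asyncio API):

import zmq
import zmq.asyncio


async def request_with_timeout(rpc_path: str, payload: bytes,
                               timeout_ms: int = 5000) -> bytes:
    # Illustrative DEALER client: send one message, then poll, so that a
    # silent server raises TimeoutError instead of hanging the coroutine.
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.DEALER)
    try:
        socket.connect(rpc_path)
        await socket.send_multipart([payload])
        # poll() resolves to 0 if no reply arrives before the deadline
        if await socket.poll(timeout=timeout_ms) == 0:
            raise TimeoutError(f"Server didn't reply within {timeout_ms} ms")
        return await socket.recv()
    finally:
        socket.close()
        context.destroy()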
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/api_server.py
@@ -6,7 +6,7 @@
 import re
 import tempfile
 from argparse import Namespace
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, suppress
 from http import HTTPStatus
 from typing import AsyncIterator, Optional, Set
@@ -83,7 +83,8 @@ async def lifespan(app: FastAPI):
     async def _force_log():
         while True:
             await asyncio.sleep(10)
-            await async_engine_client.do_log_stats()
+            with suppress(Exception):
+                await async_engine_client.do_log_stats()
 
     if not engine_args.disable_log_stats:
         task = asyncio.create_task(_force_log())
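
Note: without the suppress(Exception) guard, the first failed do_log_stats() call would raise inside the background task and silently end the logging loop for the life of the server. A minimal sketch of the pattern (illustrative names, not the vLLM implementation):

import asyncio
from contextlib import suppress


async def run_periodically(fn, interval: float = 10.0):
    # One failed iteration must not kill the loop, so each call is wrapped.
    while True:
        await asyncio.sleep(interval)
        with suppress(Exception):
            await fn()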
4 changes: 0 additions & 4 deletions vllm/entrypoints/openai/rpc/__init__.py
@@ -10,10 +10,6 @@
 # Success string used for RPC instructions.
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
-# Timeouts.
-VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
-VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
-
 # Minimum value of ZMQ.SOCKET_LIMIT to run mp.
 VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
 
70 changes: 48 additions & 22 deletions vllm/entrypoints/openai/rpc/client.py
@@ -1,5 +1,5 @@
 import asyncio
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import Any, AsyncGenerator, Mapping, Optional
 from uuid import uuid4

@@ -11,13 +11,12 @@
                          ParallelConfig, SchedulerConfig)
 # yapf: disable
 from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
-                                         VLLM_RPC_HEALTH_TIMEOUT_MS,
-                                         VLLM_RPC_SERVER_START_TIMEOUT_MS,
                                          VLLM_RPC_SOCKET_LIMIT_CUTOFF,
                                          VLLM_RPC_SUCCESS_STR,
                                          VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
                                          RPCGenerateRequest, RPCUtilityRequest)
 # yapf: enable
+from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -32,6 +31,17 @@
 INPROC_PROXY_PATH = f"inproc://{uuid4()}"
 
 
+class RPCClientClosedError(Exception):
+    """Exception class raised when the client is used post-close.
+
+    The client can be closed, which closes the ZMQ context. This normally
+    happens on server shutdown. In some cases, methods like abort and
+    do_log_stats will still be called and then try to open a socket, which
+    causes a ZMQError and creates a huge stack trace.
+    So, we throw this error such that we can suppress it.
+    """
+
+
 class AsyncEngineRPCClient:
     """
     RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
 
     def __init__(self, rpc_path: str):
         self.context = zmq.asyncio.Context()
+        self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
+        self._errored = False
 
         # Maximum number of sockets that can be opened (typically 65536).
         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@@ -143,7 +155,6 @@ async def setup(self):
 
         # Wait until server is ready.
         await self._wait_for_server_rpc()
-        self._errored = False
 
         # Get the configs.
         self.model_config = await self._get_model_config_rpc()
@@ -170,6 +181,15 @@ def close(self):
     @contextmanager
     def to_proxy_socket(self):
         # Connect to the RPCServer via the proxy.
+
+        # Raise a sensible error if the client was already closed.
+        # This can happen if a server shutdown is triggered but some coroutines
+        # are still running requests.
+        # There should not be a race condition with this check because we don't
+        # yield to the event loop between here and opening the socket.
+        if self.context.closed:
+            raise RPCClientClosedError("The ZMQ client has already shut down")
+
         # Note that we use DEALER to enable asynchronous communication
         # to enable streaming.
         socket = self.context.socket(zmq.constants.DEALER)
@@ -189,9 +209,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
             # Ping RPCServer with a request.
             await socket.send_multipart([cloudpickle.dumps(request)])
 
+            # Make sure the server responds
+            if await socket.poll(timeout=self._data_timeout) == 0:
+                raise TimeoutError("Server didn't reply within "
+                                   f"{self._data_timeout} ms")
+
             # Await the data from the Server.
             data = cloudpickle.loads(await socket.recv())
 
+            if isinstance(data, Exception):
+                # Re-raise exceptions returned by the server
+                raise data
+
             if not isinstance(data, expected_type):
                 # LoRAConfig can be None.
                 if expected_type == LoRAConfig and data is None:
@@ -208,29 +237,28 @@ async def _send_one_way_rpc_request(
             self,
             request: RPC_REQUEST_TYPE,
             error_message: str,
-            timeout: Optional[int] = None,
             socket: Optional[zmq.asyncio.Socket] = None):
         """Send one-way RPC request to trigger an action."""
 
         async def do_rpc_call(socket: zmq.asyncio.Socket,
-                              request: RPC_REQUEST_TYPE,
-                              timeout=None):
+                              request: RPC_REQUEST_TYPE):
 
             await socket.send_multipart([cloudpickle.dumps(request)])
 
-            if timeout is not None and await socket.poll(timeout=timeout) == 0:
-                raise TimeoutError(f"Server didn't reply within {timeout} ms")
+            if await socket.poll(timeout=self._data_timeout) == 0:
+                raise TimeoutError("Server didn't reply within "
+                                   f"{self._data_timeout} ms")
 
             return cloudpickle.loads(await socket.recv())
 
         # Make a new socket connection.
         if socket is None:
             with self.to_proxy_socket() as socket:
-                response = await do_rpc_call(socket, request, timeout)
+                response = await do_rpc_call(socket, request)
 
         # Use existing socket connection.
         else:
-            response = await do_rpc_call(socket, request, timeout)
+            response = await do_rpc_call(socket, request)
 
         if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
             if isinstance(response, Exception):
@@ -255,8 +283,7 @@ async def _wait_for_server_rpc(self):
 
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.IS_SERVER_READY,
-            error_message="Unable to start RPC Server",
-            timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
+            error_message="Unable to start RPC Server")
 
     async def _get_model_config_rpc(self) -> ModelConfig:
         """Get the ModelConfig object from the RPC Server"""
@@ -308,17 +335,17 @@ async def _is_tracing_enabled_rpc(self) -> bool:
 
     async def abort(self, request_id: str):
         """Send an ABORT_REQUEST signal to the RPC Server"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCAbortRequest(request_id),
-            error_message=f"RPCAbortRequest {request_id} failed")
+        with suppress(RPCClientClosedError):
+            await self._send_one_way_rpc_request(
+                request=RPCAbortRequest(request_id),
+                error_message=f"RPCAbortRequest {request_id} failed")
 
     async def do_log_stats(self):
         """Send a DO_LOG_STATS signal to the RPC Server"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCUtilityRequest.DO_LOG_STATS,
-            error_message="RPCRequest DO_LOG_STATS failed.")
+        with suppress(RPCClientClosedError):
+            await self._send_one_way_rpc_request(
+                request=RPCUtilityRequest.DO_LOG_STATS,
+                error_message="RPCRequest DO_LOG_STATS failed.")
 
     @property
     def is_running(self) -> bool:
@@ -393,7 +420,6 @@ async def check_health(self,
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.IS_SERVER_HEALTHY,
             error_message="Got Unhealthy response from RPC Server",
-            timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
             socket=socket)
 
     async def encode(self, *args,
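
Note: the closed-context guard in to_proxy_socket() and the suppress(RPCClientClosedError) wrappers in abort()/do_log_stats() work together: hard failures (check_health, generate) still surface an explicit error after shutdown, while best-effort calls become quiet no-ops. A condensed sketch of that interplay (illustrative, assuming pyzmq; ClosedError stands in for RPCClientClosedError):

from contextlib import suppress

import zmq.asyncio


class ClosedError(Exception):
    """Stand-in for RPCClientClosedError."""


class Client:

    def __init__(self):
        self.context = zmq.asyncio.Context()

    def close(self):
        self.context.destroy()

    def _ensure_open(self):
        # Raise a sensible error instead of a ZMQError deep in socket setup.
        if self.context.closed:
            raise ClosedError("The ZMQ client has already shut down")

    async def check_health(self):
        self._ensure_open()  # hard failure: caller sees ClosedError
        ...  # open a socket and send the health-check RPC here

    async def abort(self):
        # Best-effort call: after close() this becomes a silent no-op.
        with suppress(ClosedError):
            self._ensure_open()
            ...  # send the abort RPC here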
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -56,6 +56,7 @@
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
+    VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
@@ -374,6 +375,11 @@ def get_default_config_root():
     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
      ("1", "true")),
 
+    # Time in ms for the zmq client to wait for a response from the backend
+    # server for simple data operations
+    "VLLM_RPC_GET_DATA_TIMEOUT_MS":
+    lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
+
     # If set, allow running the engine as a separate ray actor,
     # which is a deprecated feature soon to be removed.
     # See https://github.com/vllm-project/vllm/issues/7045
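
Note: a hedged usage example for the new knob; since client.py imports VLLM_RPC_GET_DATA_TIMEOUT_MS at module import time, the override must be in the environment before vLLM's RPC client module is imported:

import os

# Hypothetical override: give a slow backend 30 s to answer data-plane RPCs
# (fetching model/parallel/scheduler configs) instead of the 5 s default.
os.environ["VLLM_RPC_GET_DATA_TIMEOUT_MS"] = "30000"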
