Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions vllm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,8 +709,28 @@ def cancel_tasks():


def cancel_task_threadsafe(task: Task):
if task and not task.done() and not (loop := task.get_loop()).is_closed():
loop.call_soon_threadsafe(task.cancel)
if task and not task.done():
run_in_loop(task.get_loop(), task.cancel)


def close_sockets(sockets: Sequence[Union[zmq.Socket, zmq.asyncio.Socket]]):
for sock in sockets:
if sock is not None:
sock.close(linger=0)


def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
if in_loop(loop):
function(*args)
elif not loop.is_closed():
loop.call_soon_threadsafe(function, *args)


def in_loop(event_loop: AbstractEventLoop) -> bool:
try:
return asyncio.get_running_loop() == event_loop
except RuntimeError:
return False


def make_async(
Expand Down
79 changes: 55 additions & 24 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.utils import (cancel_task_threadsafe, get_open_port,
get_open_zmq_inproc_path, make_zmq_socket)
from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path,
in_loop, make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
Expand Down Expand Up @@ -317,7 +317,7 @@ class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""

ctx: Union[zmq.Context]
ctx: zmq.Context
# If CoreEngineProcManager, it manages local engines;
# if CoreEngineActorManager, it manages all engines.
engine_manager: Optional[Union[CoreEngineProcManager,
Expand All @@ -326,6 +326,8 @@ class BackgroundResources:
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
first_req_send_socket: Optional[zmq.asyncio.Socket] = None
first_req_rcv_socket: Optional[zmq.asyncio.Socket] = None
stats_update_socket: Optional[zmq.asyncio.Socket] = None
output_queue_task: Optional[asyncio.Task] = None
stats_update_task: Optional[asyncio.Task] = None
shutdown_path: Optional[str] = None
Expand All @@ -343,23 +345,47 @@ def __call__(self):
if self.coordinator is not None:
self.coordinator.close()

cancel_task_threadsafe(self.output_queue_task)
cancel_task_threadsafe(self.stats_update_task)
if isinstance(self.output_socket, zmq.asyncio.Socket):
# Async case.
loop = self.output_socket._get_loop()
asyncio.get_running_loop()
sockets = (self.output_socket, self.input_socket,
self.first_req_send_socket, self.first_req_rcv_socket,
self.stats_update_socket)

tasks = (self.output_queue_task, self.stats_update_task)

def close_sockets_and_tasks():
close_sockets(sockets)
for task in tasks:
if task is not None and not task.done():
task.cancel()

if in_loop(loop):
close_sockets_and_tasks()
elif not loop.is_closed():
loop.call_soon_threadsafe(close_sockets_and_tasks)
else:
# Loop has been closed, try to clean up directly.
del tasks
del close_sockets_and_tasks
close_sockets(sockets)
del self.output_queue_task
del self.stats_update_task
else:
# Sync case.

# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
for socket in (self.output_socket, self.input_socket,
self.first_req_send_socket):
if socket is not None:
socket.close(linger=0)
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
close_sockets((self.output_socket, self.input_socket))

if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
shutdown_sender.connect(self.shutdown_path)
# Send shutdown signal.
shutdown_sender.send(b'')
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
shutdown_sender.connect(self.shutdown_path)
# Send shutdown signal.
shutdown_sender.send(b'')

def validate_alive(self, frames: Sequence[zmq.Frame]):
if len(frames) == 1 and (frames[0].buffer
Expand Down Expand Up @@ -969,14 +995,19 @@ def _ensure_stats_update_task(self):
self.engine_ranks_managed[-1] + 1)

async def run_engine_stats_update_task():
with make_zmq_socket(self.ctx, self.stats_update_address,
zmq.XSUB) as socket, make_zmq_socket(
self.ctx,
self.first_req_sock_addr,
zmq.PAIR,
bind=False) as first_req_rcv_socket:
with (make_zmq_socket(self.ctx,
self.stats_update_address,
zmq.XSUB,
linger=0) as socket,
make_zmq_socket(self.ctx,
self.first_req_sock_addr,
zmq.PAIR,
bind=False,
linger=0) as first_req_rcv_socket):
assert isinstance(socket, zmq.asyncio.Socket)
assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket)
self.resources.stats_update_socket = socket
self.resources.first_req_rcv_socket = first_req_rcv_socket
# Send subscription message.
await socket.send(b'\x01')

Expand Down