Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 56d2002

Browse files
njhill and Robert Shaw
authored and committed
[Core] Add multiproc_worker_utils for multiprocessing-based workers (vllm-project#4357)
1 parent 3d32972 commit 56d2002

File tree

2 files changed

+440
-0
lines changed

2 files changed

+440
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import asyncio
2+
from concurrent.futures import ThreadPoolExecutor
3+
from functools import partial
4+
from time import sleep
5+
from typing import Any, List, Tuple
6+
7+
import pytest
8+
9+
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
10+
ResultHandler, WorkerMonitor)
11+
12+
13+
class DummyWorker:
    """Dummy version of vllm.worker.worker.Worker"""

    def __init__(self, rank: int):
        self.rank = rank

    def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
        """Echo ``worker_input`` back, tagged with this worker's rank.

        Sleeps briefly to simulate work. If ``worker_input`` is an
        Exception instance, raises it to simulate a failing task.
        """
        sleep(0.05)

        if isinstance(worker_input, Exception):
            # simulate error case
            raise worker_input

        # Return the actual payload so callers can verify delivery.
        # (Previously this returned the `input` builtin — a shadowing
        # typo for `worker_input` — which made the payload assertions
        # in the tests below compare builtin-to-builtin, i.e. vacuous.)
        return self.rank, worker_input


def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
    """Start 8 worker processes plus their result handler and monitor.

    Returns the worker wrappers and the (already started) monitor.
    """
    result_handler = ResultHandler()
    workers = [
        ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
        for rank in range(8)
    ]

    worker_monitor = WorkerMonitor(workers, result_handler)
    # The monitor thread must not be running until start() is called.
    assert not worker_monitor.is_alive()

    result_handler.start()
    worker_monitor.start()
    assert worker_monitor.is_alive()

    return workers, worker_monitor


def test_local_workers() -> None:
    """Test workers with sync task submission"""

    workers, worker_monitor = _start_workers()

    def execute_workers(worker_input: str) -> None:
        worker_outputs = [
            worker.execute_method("worker_method", worker_input)
            for worker in workers
        ]

        for rank, output in enumerate(worker_outputs):
            # Each worker must echo back its rank and the submitted input.
            assert output.get() == (rank, worker_input)

    # Test concurrent submission from different threads. The context
    # manager guarantees the pool is shut down even if the test fails
    # (the original leaked the executor).
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(partial(execute_workers, f"thread {thread_num}"))
            for thread_num in range(4)
        ]

        for future in futures:
            future.result()

    # Test error case: the worker-side exception must propagate to get().
    exception = ValueError("fake error")
    result = workers[0].execute_method("worker_method", exception)
    with pytest.raises(ValueError, match="fake error"):
        result.get()

    # Test cleanup when a worker fails
    assert worker_monitor.is_alive()
    workers[3].process.kill()

    # Other workers should get shut down here
    worker_monitor.join(2)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        workers[0].execute_method("worker_method", "test")


def test_local_workers_clean_shutdown() -> None:
    """Test clean shutdown"""

    workers, worker_monitor = _start_workers()

    assert worker_monitor.is_alive()
    assert all(worker.process.is_alive() for worker in workers)

    # Clean shutdown
    worker_monitor.close()

    worker_monitor.join(5)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        workers[0].execute_method("worker_method", "test")


@pytest.mark.asyncio
async def test_local_workers_async() -> None:
    """Test local workers with async task submission"""

    workers, worker_monitor = _start_workers()

    async def execute_workers(worker_input: str) -> None:
        worker_coros = [
            worker.execute_method_async("worker_method", worker_input)
            for worker in workers
        ]

        results = await asyncio.gather(*worker_coros)
        for rank, result in enumerate(results):
            # Each worker must echo back its rank and the submitted input.
            assert result == (rank, worker_input)

    tasks = [
        asyncio.create_task(execute_workers(f"task {task_num}"))
        for task_num in range(4)
    ]

    for task in tasks:
        await task

    # Test error case: the worker-side exception must propagate here.
    exception = ValueError("fake error")
    with pytest.raises(ValueError, match="fake error"):
        await workers[0].execute_method_async("worker_method", exception)

    # Test cleanup when a worker fails
    assert worker_monitor.is_alive()
    workers[3].process.kill()

    # Other workers should get shut down here
    worker_monitor.join(2)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    with pytest.raises(ChildProcessError):
        await workers[0].execute_method_async("worker_method", "test")

0 commit comments

Comments
 (0)