Skip to content

Commit a39a86a

Browse files
committed
[Core] Get multiprocessing context at runtime
Instead of getting the multiprocessing context at import time, get it at runtime. This allows other code in vllm to change this env var and have it take effect here. Signed-off-by: Russell Bryant <rbryant@redhat.com>
1 parent cc4325b commit a39a86a

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

vllm/executor/multiproc_worker_utils.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727

2828
JOIN_TIMEOUT_S = 2
2929

30-
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
31-
mp = multiprocessing.get_context(mp_method)
32-
3330

3431
@dataclass
3532
class Result(Generic[T]):
@@ -77,7 +74,7 @@ class ResultHandler(threading.Thread):
7774

7875
def __init__(self) -> None:
7976
super().__init__(daemon=True)
80-
self.result_queue = mp.Queue()
77+
self.result_queue = get_mp_context().Queue()
8178
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
8279

8380
def run(self):
@@ -147,10 +144,11 @@ class ProcessWorkerWrapper:
147144

148145
def __init__(self, result_handler: ResultHandler,
149146
worker_factory: Callable[[], Any]) -> None:
150-
self._task_queue = mp.Queue()
147+
self.mp = get_mp_context()
148+
self._task_queue = self.mp.Queue()
151149
self.result_queue = result_handler.result_queue
152150
self.tasks = result_handler.tasks
153-
self.process: BaseProcess = mp.Process( # type: ignore[attr-defined]
151+
self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined]
154152
target=_run_worker_process,
155153
name="VllmWorkerProcess",
156154
kwargs=dict(
@@ -204,7 +202,7 @@ def _run_worker_process(
204202
"""Worker process event loop"""
205203

206204
# Add process-specific prefix to stdout and stderr
207-
process_name = mp.current_process().name
205+
process_name = get_mp_context().current_process().name
208206
pid = os.getpid()
209207
_add_prefix(sys.stdout, process_name, pid)
210208
_add_prefix(sys.stderr, process_name, pid)
@@ -269,3 +267,8 @@ def write_with_prefix(s: str):
269267

270268
file.start_new_line = True # type: ignore[attr-defined]
271269
file.write = write_with_prefix # type: ignore[method-assign]
270+
271+
272+
def get_mp_context():
273+
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
274+
return multiprocessing.get_context(mp_method)

0 commit comments

Comments
 (0)