
Commit d370579

tlrmchlsmth authored and weilong.yu committed
[V1] Multiprocessing Tensor Parallel Support for v1 (vllm-project#9856)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
1 parent cff864e commit d370579

21 files changed: +733, -146 lines

tests/basic_correctness/test_basic_correctness.py

Lines changed: 16 additions & 0 deletions
@@ -26,6 +26,14 @@
 TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def test_vllm_gc_ed():
     """Verify vllm instance is GC'ed when it is deleted"""
     llm = LLM("facebook/opt-125m")
@@ -36,6 +44,7 @@ def test_vllm_gc_ed():
     assert weak_llm() is None


+@pytest.mark.skip_v1
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("dtype", ["half"])
@@ -118,6 +127,11 @@ def test_models_distributed(
     if attention_backend:
         os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend

+    # Import VLLM_USE_V1 dynamically to handle patching
+    from vllm.envs import VLLM_USE_V1
+    if VLLM_USE_V1 and distributed_executor_backend != "mp":
+        pytest.skip(f"Skip {distributed_executor_backend} for V1")
+
     dtype = "half"
     max_tokens = 5

@@ -143,6 +157,7 @@ def test_models_distributed(
     )


+@pytest.mark.skip_v1
 def test_model_with_failure(vllm_runner) -> None:
     try:
         with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
@@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None:
         os.remove(filename)


+@pytest.mark.skip_v1
 def test_failure_with_async_out_proc(vllm_runner) -> None:

     filename = None

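For context, the autouse v1 fixture added above is the per-module opt-in for dual-engine runs: any test module that defines it gets every test collected twice by the parametrized run_with_both_engines fixture in conftest.py. A minimal standalone sketch of the same pattern (the test name below is hypothetical, not part of this commit):

import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Requesting run_with_both_engines (params=[True, False] in conftest.py)
    # is enough for pytest to collect each test in this module twice:
    # once with VLLM_USE_V1=1 and once with VLLM_USE_V1=0.
    pass


@pytest.mark.skip_v1
def test_v0_only_behavior():
    # The skip_v1 marker is honored inside run_with_both_engines, which
    # calls pytest.skip() on the V1 parametrization of this test.
    ...
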
tests/conftest.py

Lines changed: 5 additions & 6 deletions
@@ -5,7 +5,6 @@
 from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                     TypedDict, TypeVar, Union)
-from unittest.mock import patch

 import numpy as np
 import pytest
@@ -110,7 +109,7 @@ def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:


 @pytest.fixture(params=[True, False])
-def run_with_both_engines(request):
+def run_with_both_engines(request, monkeypatch):
     # Automatically runs tests twice, once with V1 and once without
     use_v1 = request.param
     # Tests decorated with `@skip_v1` are only run without v1
@@ -119,11 +118,11 @@ def run_with_both_engines(request):
     if use_v1:
         if skip_v1:
             pytest.skip("Skipping test on vllm V1")
-        with patch('vllm.envs.VLLM_USE_V1', True):
-            yield
+        monkeypatch.setenv('VLLM_USE_V1', '1')
     else:
-        with patch('vllm.envs.VLLM_USE_V1', False):
-            yield
+        monkeypatch.setenv('VLLM_USE_V1', '0')
+
+    yield


 @pytest.fixture(autouse=True)

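The fixture now sets the VLLM_USE_V1 environment variable via monkeypatch instead of patching the vllm.envs module attribute. Presumably this is because the environment variable is visible both to vllm.envs' lazy, environment-backed lookup and to any worker processes the multiprocessing executor spawns, whereas a mock.patch on the module attribute only affects the parent process. A minimal standalone sketch of the idiom; the use_v1 helper and both_engines fixture below are simplified stand-ins, not vllm's API:

import os

import pytest


def use_v1() -> bool:
    # Stand-in for a lazy, environment-backed flag: the value is read from
    # os.environ at call time, not captured at import time.
    return os.environ.get("VLLM_USE_V1", "0") == "1"


@pytest.fixture(params=[True, False])
def both_engines(request, monkeypatch):
    # monkeypatch.setenv mutates os.environ for the duration of the test and
    # restores the old value afterwards; child processes started while the
    # test runs inherit the setting.
    monkeypatch.setenv("VLLM_USE_V1", "1" if request.param else "0")
    yield request.param


def test_flag_matches_param(both_engines):
    # Runs twice; the env-backed flag always matches the parametrization.
    assert use_v1() == both_engines
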
vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 52 additions & 24 deletions
@@ -1,10 +1,11 @@
 import os
 import pickle
+import sys
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional
+from typing import List, Optional, Tuple
 from unittest.mock import patch

 import torch
@@ -21,6 +22,20 @@

 logger = init_logger(__name__)

+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on earlier versions of Python
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
+USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
+                   or (sys.version_info[:2] == (3, 10)
+                       and sys.version_info[2] >= 8))
+
+
+def sched_yield():
+    if USE_SCHED_YIELD:
+        os.sched_yield()
+    else:
+        time.sleep(0)
+

 class ShmRingBuffer:

@@ -114,11 +129,14 @@ def __init__(self,
                     # and we should suppress the error
                     pass

+    def handle(self):
+        return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
+                self.shared_memory.name)
+
     def __reduce__(self):
         return (
             self.__class__,
-            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
-             self.shared_memory.name),
+            self.handle(),
         )

     def __del__(self):
@@ -147,7 +165,7 @@ class Handle:
     connect_ip: str
     local_reader_ranks: List[int] = field(default_factory=list)

-    buffer: Optional[ShmRingBuffer] = None
+    buffer_handle: Optional[Tuple[int, int, int, str]] = None
     local_subscribe_port: Optional[int] = None
     remote_subscribe_port: Optional[int] = None

@@ -228,7 +246,7 @@ def __init__(
         self.handle = Handle(
             connect_ip=connect_ip,
             local_reader_ranks=local_reader_ranks,
-            buffer=self.buffer,
+            buffer_handle=self.buffer.handle(),
             local_subscribe_port=local_subscribe_port,
             remote_subscribe_port=remote_subscribe_port,
         )
@@ -247,8 +265,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
         context = Context()

         if rank in handle.local_reader_ranks:
-            assert handle.buffer is not None
-            self.buffer = handle.buffer
+            assert handle.buffer_handle is not None
+            self.buffer = ShmRingBuffer(*handle.buffer_handle)
             self.current_idx = 0
             self.local_reader_rank = handle.local_reader_ranks.index(rank)
             self._is_local_reader = True
@@ -314,7 +332,7 @@ def wait_until_ready(self):
         assert recv == b"READY"

     @contextmanager
-    def acquire_write(self):
+    def acquire_write(self, timeout: Optional[float] = None):
         assert self._is_writer, "Only writers can acquire write"
         start_time = time.monotonic()
         n_warning = 1
@@ -329,16 +347,20 @@ def acquire_write(self):
                     # we need to wait until it is read by all readers

                     # Release the processor to other threads
-                    os.sched_yield()
+                    sched_yield()

-                    # if we wait for a long time, we should warn the user
+                    # if we wait for a long time, log a message
                     if (time.monotonic() - start_time >
                             VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
-                        logger.warning(
-                            "No available block found in %s second. ",
-                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        logger.debug("No available block found in %s second. ",
+                                     VLLM_RINGBUFFER_WARNING_INTERVAL)
                         n_warning += 1

+                    # if we time out, raise an exception
+                    if (timeout is not None
+                            and time.monotonic() - start_time > timeout):
+                        raise TimeoutError
+
                     continue
                 # found a block that is either
                 # (1) not written
@@ -365,7 +387,7 @@ def acquire_write(self):
                 break

     @contextmanager
-    def acquire_read(self):
+    def acquire_read(self, timeout: Optional[float] = None):
         assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
@@ -383,16 +405,20 @@ def acquire_read(self):
                     # we need to wait until it is written

                     # Release the processor to other threads
-                    os.sched_yield()
+                    sched_yield()

-                    # if we wait for a long time, we should warn the user
+                    # if we wait for a long time, log a message
                     if (time.monotonic() - start_time >
                             VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
-                        logger.warning(
-                            "No available block found in %s second. ",
-                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        logger.debug("No available block found in %s second. ",
+                                     VLLM_RINGBUFFER_WARNING_INTERVAL)
                         n_warning += 1

+                    # if we time out, raise an exception
+                    if (timeout is not None
+                            and time.monotonic() - start_time > timeout):
+                        raise TimeoutError
+
                     continue
                 # found a block that is not read by this reader
                 # let caller read from the buffer
@@ -406,24 +432,26 @@
                                     1) % self.buffer.max_chunks
                 break

-    def enqueue(self, obj):
+    def enqueue(self, obj, timeout: Optional[float] = None):
+        """ Write to message queue with optional timeout (in seconds) """
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
         if self.n_local_reader > 0:
             if len(serialized_obj) >= self.buffer.max_chunk_bytes:
-                with self.acquire_write() as buf:
+                with self.acquire_write(timeout) as buf:
                     buf[0] = 1  # overflow
                 self.local_socket.send(serialized_obj)
             else:
-                with self.acquire_write() as buf:
+                with self.acquire_write(timeout) as buf:
                     buf[0] = 0  # not overflow
                     buf[1:len(serialized_obj) + 1] = serialized_obj
         if self.n_remote_reader > 0:
             self.remote_socket.send(serialized_obj)

-    def dequeue(self):
+    def dequeue(self, timeout: Optional[float] = None):
+        """ Read from message queue with optional timeout (in seconds) """
         if self._is_local_reader:
-            with self.acquire_read() as buf:
+            with self.acquire_read(timeout) as buf:
                 overflow = buf[0] == 1
                 if not overflow:
                     # no need to know the size of serialized object

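The timeout parameter threaded through enqueue/dequeue into acquire_write/acquire_read bounds how long a caller spins on a full (or empty) ring buffer before a TimeoutError is raised; callers that pass no timeout keep the old block-forever behavior. Below is a standalone sketch of the same polling pattern, using the identical Python-version gate for os.sched_yield; poll_until and its arguments are illustrative helpers, not vllm APIs:

import os
import sys
import time

# Same gate as the diff: os.sched_yield only releases the GIL on
# Python 3.10.8+ / 3.11.1+, so older interpreters fall back to time.sleep(0).
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
                   or (sys.version_info[:2] == (3, 10)
                       and sys.version_info[2] >= 8))


def sched_yield():
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)


def poll_until(condition, timeout=None, warn_interval=60.0):
    """Spin until condition() is true, yielding the CPU between checks.

    Mirrors the acquire_write/acquire_read loops: yield, periodically note
    the wait, and raise TimeoutError once the optional deadline passes.
    """
    start = time.monotonic()
    n_warning = 1
    while not condition():
        sched_yield()
        elapsed = time.monotonic() - start
        if elapsed > warn_interval * n_warning:
            print(f"still waiting after {elapsed:.0f}s")
            n_warning += 1
        if timeout is not None and elapsed > timeout:
            raise TimeoutError

On MessageQueue itself, the same contract surfaces as queue.enqueue(obj, timeout=...) and queue.dequeue(timeout=...), so a stalled peer now produces a TimeoutError in the caller instead of an unbounded spin.
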
vllm/executor/multiproc_gpu_executor.py

Lines changed: 6 additions & 41 deletions
@@ -3,25 +3,19 @@
 from functools import partial
 from typing import Any, List, Optional

-import torch
-
 from vllm.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.gpu_executor import create_worker
-from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
-                                                  ResultHandler, WorkerMonitor)
+from vllm.executor.multiproc_worker_utils import (
+    ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
+    set_multiprocessing_worker_envs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
-from vllm.triton_utils.importing import HAS_TRITON
 from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
-                        cuda_is_initialized, get_distributed_init_method,
-                        get_open_port, make_async,
+                        get_distributed_init_method, get_open_port, make_async,
                         update_environment_variables)

-if HAS_TRITON:
-    from vllm.triton_utils import maybe_set_triton_cache_manager
-
 logger = init_logger(__name__)


@@ -37,30 +31,8 @@ def _init_executor(self) -> None:
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size

-        # Disable torch async compiling which won't work with daemonic processes
-        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
-
-        # Configure thread parallelism if OMP_NUM_THREADS isn't set
-        #
-        # Helps to avoid CPU contention. The default of spawning a thread per
-        # core combined with multiprocessing for each GPU can have a negative
-        # impact on performance. The contention is amplified when running in a
-        # container where CPU limits can cause throttling.
-        default_omp_num_threads = 1
-        if "OMP_NUM_THREADS" not in os.environ and (
-                current_parallelism :=
-                torch.get_num_threads()) > default_omp_num_threads:
-            logger.warning(
-                "Reducing Torch parallelism from %d threads to %d to avoid "
-                "unnecessary CPU contention. Set OMP_NUM_THREADS in the "
-                "external environment to tune this value as needed.",
-                current_parallelism, default_omp_num_threads)
-            os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
-            torch.set_num_threads(default_omp_num_threads)
-
-        # workaround for https://github.com/vllm-project/vllm/issues/6103
-        if HAS_TRITON and world_size > 1:
-            maybe_set_triton_cache_manager()
+        # Set multiprocessing envs that are common to V0 and V1
+        set_multiprocessing_worker_envs(self.parallel_config)

         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
@@ -122,13 +94,6 @@ def _check_executor_parameters(self):
                 "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
             })

-        if (cuda_is_initialized()
-                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
-            logger.warning("CUDA was previously initialized. We must use "
-                           "the `spawn` multiprocessing start method. Setting "
-                           "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
-            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
         cuda_device_count = cuda_device_count_stateless()
         # Use confusing message for more common TP-only case.
         assert tensor_parallel_size <= cuda_device_count, (

vllm/executor/multiproc_worker_utils.py

Lines changed: 42 additions & 0 deletions
@@ -11,8 +11,15 @@
 from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
                     TypeVar, Union)

+import torch
+
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.triton_utils.importing import HAS_TRITON
+from vllm.utils import cuda_is_initialized
+
+if HAS_TRITON:
+    from vllm.triton_utils import maybe_set_triton_cache_manager

 logger = init_logger(__name__)

@@ -270,3 +277,38 @@ def write_with_prefix(s: str):
 def get_mp_context():
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
+
+
+def set_multiprocessing_worker_envs(parallel_config):
+    """ Set up environment variables that should be used when there are workers
+    in a multiprocessing environment. This should be called by the parent
+    process before worker processes are created"""
+
+    if (cuda_is_initialized()
+            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
+        logger.warning("CUDA was previously initialized. We must use "
+                       "the `spawn` multiprocessing start method. Setting "
+                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+    # Configure thread parallelism if OMP_NUM_THREADS isn't set
+    #
+    # Helps to avoid CPU contention. The default of spawning a thread per
+    # core combined with multiprocessing for each GPU can have a negative
+    # impact on performance. The contention is amplified when running in a
+    # container where CPU limits can cause throttling.
+    default_omp_num_threads = 1
+    if "OMP_NUM_THREADS" not in os.environ and (
+            current_parallelism :=
+            torch.get_num_threads()) > default_omp_num_threads:
+        logger.warning(
+            "Reducing Torch parallelism from %d threads to %d to avoid "
+            "unnecessary CPU contention. Set OMP_NUM_THREADS in the "
+            "external environment to tune this value as needed.",
+            current_parallelism, default_omp_num_threads)
+        os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
+        torch.set_num_threads(default_omp_num_threads)
+
+    # workaround for https://github.com/vllm-project/vllm/issues/6103
+    if HAS_TRITON and parallel_config.world_size > 1:
+        maybe_set_triton_cache_manager()

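set_multiprocessing_worker_envs collects the parent-process setup that multiproc_gpu_executor.py previously did inline, so both the V0 executor and the new V1 multiprocessing path can call one helper before spawning workers. A minimal sketch of the intended call order in a parent process; the SimpleNamespace stands in for vllm's ParallelConfig (the helper only reads world_size from it), and the worker target here is just a placeholder:

from types import SimpleNamespace

from vllm.executor.multiproc_worker_utils import (
    get_mp_context, set_multiprocessing_worker_envs)

# Illustrative stand-in for ParallelConfig; only world_size is consulted.
parallel_config = SimpleNamespace(world_size=2)

# 1) Fix up the environment (spawn start method if CUDA is already
#    initialized, OMP_NUM_THREADS, triton cache manager) before any
#    worker process exists.
set_multiprocessing_worker_envs(parallel_config)

# 2) Only then create worker processes using the configured start method.
mp = get_mp_context()
worker = mp.Process(target=print, args=("worker process started",))
worker.start()
worker.join()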