Skip to content

Commit adb3867

Browse files
youkaichaoAlvant
authored andcommitted
[bugfix][distributed] fix shm broadcast when the queue size is full (vllm-project#5801)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent de34ae1 commit adb3867

File tree

2 files changed

+76
-46
lines changed

2 files changed

+76
-46
lines changed

tests/distributed/test_shm_broadcast.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
import multiprocessing
22
import random
33
import time
4+
from typing import List
45

6+
import numpy as np
57
import torch.distributed as dist
68

79
from vllm.distributed.device_communicators.shm_broadcast import (
810
ShmRingBuffer, ShmRingBufferIO)
911
from vllm.utils import update_environment_variables
1012

1113

14+
def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
15+
np.random.seed(seed)
16+
sizes = np.random.randint(1, 10_000, n)
17+
# on average, each array will have 5k elements
18+
# with int64, each array will have 40kb
19+
return [np.random.randint(1, 100, i) for i in sizes]
20+
21+
1222
def distributed_run(fn, world_size):
1323
number_of_processes = world_size
1424
processes = []
@@ -47,24 +57,31 @@ def wrapped_fn(env):
4757
def worker_fn():
4858
writer_rank = 2
4959
broadcaster = ShmRingBufferIO.create_from_process_group(
50-
dist.group.WORLD, 1024, 2, writer_rank)
60+
dist.group.WORLD, 1024 * 1024, 2, writer_rank)
61+
if dist.get_rank() == writer_rank:
62+
seed = random.randint(0, 1000)
63+
dist.broadcast_object_list([seed], writer_rank)
64+
else:
65+
recv = [None]
66+
dist.broadcast_object_list(recv, writer_rank)
67+
seed = recv[0] # type: ignore
68+
dist.barrier()
69+
# in case we find a race condition
70+
# print the seed so that we can reproduce the error
71+
print(f"Rank {dist.get_rank()} got seed {seed}")
72+
# test broadcasting with about 400MB of data
73+
N = 10_000
5174
if dist.get_rank() == writer_rank:
52-
time.sleep(random.random())
53-
broadcaster.broadcast_object(0)
54-
time.sleep(random.random())
55-
broadcaster.broadcast_object({})
56-
time.sleep(random.random())
57-
broadcaster.broadcast_object([])
75+
arrs = get_arrays(N, seed)
76+
for x in arrs:
77+
broadcaster.broadcast_object(x)
78+
time.sleep(random.random() / 1000)
5879
else:
59-
time.sleep(random.random())
60-
a = broadcaster.broadcast_object(None)
61-
time.sleep(random.random())
62-
b = broadcaster.broadcast_object(None)
63-
time.sleep(random.random())
64-
c = broadcaster.broadcast_object(None)
65-
assert a == 0
66-
assert b == {}
67-
assert c == []
80+
arrs = get_arrays(N, seed)
81+
for x in arrs:
82+
y = broadcaster.broadcast_object(None)
83+
assert np.array_equal(x, y)
84+
time.sleep(random.random() / 1000)
6885
dist.barrier()
6986

7087

vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414

1515
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
1616

17+
# time to wait if the queue is full or empty
18+
# if we sleep for too short, it will consume too much CPU
19+
# if we sleep for too long, it will slow down the writer/reader
20+
# 0.1 us is a good balance
21+
RINGBUFFER_SLEEP_INTERVAL = 1e-7
22+
1723
logger = init_logger(__name__)
1824

1925

@@ -145,28 +151,29 @@ def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
145151
@contextmanager
146152
def acquire_write(self):
147153
assert self._is_writer, "Only writers can acquire write"
148-
start_index = self.current_idx
149-
start_time = time.time()
154+
start_time = time.monotonic()
150155
n_warning = 1
151156
while True:
152157
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
153158
read_count = sum(metadata_buffer[1:])
154159
written_flag = metadata_buffer[0]
155160
if written_flag and read_count != self.buffer.n_reader:
156161
# this block is written and not read by all readers
157-
# try to write to the next block
158-
self.current_idx = (self.current_idx +
159-
1) % self.buffer.max_chunks
160-
if self.current_idx == start_index:
161-
# no empty block found
162-
if time.time(
163-
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
164-
logger.warning(
165-
"No available block found in %s second. ",
166-
VLLM_RINGBUFFER_WARNING_INTERVAL)
167-
n_warning += 1
168-
# wait for a while (0.1 us)
169-
time.sleep(1e-7)
162+
# for writers, `self.current_idx` is the next block to write
163+
# if this block is not ready to write,
164+
# we need to wait until it is read by all readers
165+
166+
# wait for a while
167+
time.sleep(RINGBUFFER_SLEEP_INTERVAL)
168+
169+
# if we wait for a long time, we should warn the user
170+
if time.monotonic(
171+
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
172+
logger.warning(
173+
"No available block found in %s second. ",
174+
VLLM_RINGBUFFER_WARNING_INTERVAL)
175+
n_warning += 1
176+
170177
continue
171178
# found a block that is either
172179
# (1) not written
@@ -188,13 +195,14 @@ def acquire_write(self):
188195
metadata_buffer[i] = 0
189196
# mark the block as written
190197
metadata_buffer[0] = 1
198+
self.current_idx = (self.current_idx +
199+
1) % self.buffer.max_chunks
191200
break
192201

193202
@contextmanager
194203
def acquire_read(self):
195204
assert self._is_reader, "Only readers can acquire read"
196-
start_index = self.current_idx
197-
start_time = time.time()
205+
start_time = time.monotonic()
198206
n_warning = 1
199207
while True:
200208
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
@@ -204,19 +212,22 @@ def acquire_read(self):
204212
# this block is either
205213
# (1) not written
206214
# (2) already read by this reader
207-
# try to read the next block
208-
self.current_idx = (self.current_idx +
209-
1) % self.buffer.max_chunks
210-
if self.current_idx == start_index:
211-
# no block found
212-
if time.time(
213-
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
214-
logger.warning(
215-
"No available block found in %s second. ",
216-
VLLM_RINGBUFFER_WARNING_INTERVAL)
217-
n_warning += 1
218-
# wait for a while (0.1 us)
219-
time.sleep(1e-7)
215+
216+
# for readers, `self.current_idx` is the next block to read
217+
# if this block is not ready,
218+
# we need to wait until it is written
219+
220+
# wait for a while
221+
time.sleep(RINGBUFFER_SLEEP_INTERVAL)
222+
223+
# if we wait for a long time, we should warn the user
224+
if time.monotonic(
225+
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
226+
logger.warning(
227+
"No available block found in %s second. ",
228+
VLLM_RINGBUFFER_WARNING_INTERVAL)
229+
n_warning += 1
230+
220231
continue
221232
# found a block that is not read by this reader
222233
# let caller read from the buffer
@@ -226,6 +237,8 @@ def acquire_read(self):
226237
# caller has read from the buffer
227238
# set the read flag
228239
metadata_buffer[self.reader_rank + 1] = 1
240+
self.current_idx = (self.current_idx +
241+
1) % self.buffer.max_chunks
229242
break
230243

231244
def enqueue(self, obj):

0 commit comments

Comments
 (0)