Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit cb364ef

Browse files
youkaichao and Robert Shaw
authored and committed
[bugfix][distributed] fix shm broadcast when the queue size is full (vllm-project#5801)
1 parent ce9da79 commit cb364ef

File tree

2 files changed

+76
-46
lines changed

2 files changed

+76
-46
lines changed

tests/distributed/test_shm_broadcast.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import multiprocessing
22
import random
33
import time
4+
from typing import List
45

6+
import numpy as np
57
import pytest
68
import torch.distributed as dist
79

@@ -15,6 +17,14 @@
1517
allow_module_level=True)
1618

1719

20+
def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
21+
np.random.seed(seed)
22+
sizes = np.random.randint(1, 10_000, n)
23+
# on average, each array will have 5k elements
24+
# with int64, each array will have 40kb
25+
return [np.random.randint(1, 100, i) for i in sizes]
26+
27+
1828
def distributed_run(fn, world_size):
1929
number_of_processes = world_size
2030
processes = []
@@ -53,24 +63,31 @@ def wrapped_fn(env):
5363
def worker_fn():
5464
writer_rank = 2
5565
broadcaster = ShmRingBufferIO.create_from_process_group(
56-
dist.group.WORLD, 1024, 2, writer_rank)
66+
dist.group.WORLD, 1024 * 1024, 2, writer_rank)
67+
if dist.get_rank() == writer_rank:
68+
seed = random.randint(0, 1000)
69+
dist.broadcast_object_list([seed], writer_rank)
70+
else:
71+
recv = [None]
72+
dist.broadcast_object_list(recv, writer_rank)
73+
seed = recv[0] # type: ignore
74+
dist.barrier()
75+
# in case we find a race condition
76+
# print the seed so that we can reproduce the error
77+
print(f"Rank {dist.get_rank()} got seed {seed}")
78+
# test broadcasting with about 400MB of data
79+
N = 10_000
5780
if dist.get_rank() == writer_rank:
58-
time.sleep(random.random())
59-
broadcaster.broadcast_object(0)
60-
time.sleep(random.random())
61-
broadcaster.broadcast_object({})
62-
time.sleep(random.random())
63-
broadcaster.broadcast_object([])
81+
arrs = get_arrays(N, seed)
82+
for x in arrs:
83+
broadcaster.broadcast_object(x)
84+
time.sleep(random.random() / 1000)
6485
else:
65-
time.sleep(random.random())
66-
a = broadcaster.broadcast_object(None)
67-
time.sleep(random.random())
68-
b = broadcaster.broadcast_object(None)
69-
time.sleep(random.random())
70-
c = broadcaster.broadcast_object(None)
71-
assert a == 0
72-
assert b == {}
73-
assert c == []
86+
arrs = get_arrays(N, seed)
87+
for x in arrs:
88+
y = broadcaster.broadcast_object(None)
89+
assert np.array_equal(x, y)
90+
time.sleep(random.random() / 1000)
7491
dist.barrier()
7592

7693

vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414

1515
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
1616

17+
# time to wait if the queue is full or empty
18+
# if the sleep is too short, the polling loop will consume too much CPU
19+
# if the sleep is too long, it will slow down the writer/reader
20+
# 0.1 us is a good balance
21+
RINGBUFFER_SLEEP_INTERVAL = 1e-7
22+
1723
logger = init_logger(__name__)
1824

1925

@@ -145,28 +151,29 @@ def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
145151
@contextmanager
146152
def acquire_write(self):
147153
assert self._is_writer, "Only writers can acquire write"
148-
start_index = self.current_idx
149-
start_time = time.time()
154+
start_time = time.monotonic()
150155
n_warning = 1
151156
while True:
152157
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
153158
read_count = sum(metadata_buffer[1:])
154159
written_flag = metadata_buffer[0]
155160
if written_flag and read_count != self.buffer.n_reader:
156161
# this block is written and not read by all readers
157-
# try to write to the next block
158-
self.current_idx = (self.current_idx +
159-
1) % self.buffer.max_chunks
160-
if self.current_idx == start_index:
161-
# no empty block found
162-
if time.time(
163-
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
164-
logger.warning(
165-
"No available block found in %s second. ",
166-
VLLM_RINGBUFFER_WARNING_INTERVAL)
167-
n_warning += 1
168-
# wait for a while (0.1 us)
169-
time.sleep(1e-7)
162+
# for writers, `self.current_idx` is the next block to write
163+
# if this block is not ready to write,
164+
# we need to wait until it is read by all readers
165+
166+
# wait for a while
167+
time.sleep(RINGBUFFER_SLEEP_INTERVAL)
168+
169+
# if we wait for a long time, we should warn the user
170+
if time.monotonic(
171+
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
172+
logger.warning(
173+
"No available block found in %s second. ",
174+
VLLM_RINGBUFFER_WARNING_INTERVAL)
175+
n_warning += 1
176+
170177
continue
171178
# found a block that is either
172179
# (1) not written
@@ -188,13 +195,14 @@ def acquire_write(self):
188195
metadata_buffer[i] = 0
189196
# mark the block as written
190197
metadata_buffer[0] = 1
198+
self.current_idx = (self.current_idx +
199+
1) % self.buffer.max_chunks
191200
break
192201

193202
@contextmanager
194203
def acquire_read(self):
195204
assert self._is_reader, "Only readers can acquire read"
196-
start_index = self.current_idx
197-
start_time = time.time()
205+
start_time = time.monotonic()
198206
n_warning = 1
199207
while True:
200208
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
@@ -204,19 +212,22 @@ def acquire_read(self):
204212
# this block is either
205213
# (1) not written
206214
# (2) already read by this reader
207-
# try to read the next block
208-
self.current_idx = (self.current_idx +
209-
1) % self.buffer.max_chunks
210-
if self.current_idx == start_index:
211-
# no block found
212-
if time.time(
213-
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
214-
logger.warning(
215-
"No available block found in %s second. ",
216-
VLLM_RINGBUFFER_WARNING_INTERVAL)
217-
n_warning += 1
218-
# wait for a while (0.1 us)
219-
time.sleep(1e-7)
215+
216+
# for readers, `self.current_idx` is the next block to read
217+
# if this block is not ready,
218+
# we need to wait until it is written
219+
220+
# wait for a while
221+
time.sleep(RINGBUFFER_SLEEP_INTERVAL)
222+
223+
# if we wait for a long time, we should warn the user
224+
if time.monotonic(
225+
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
226+
logger.warning(
227+
"No available block found in %s second. ",
228+
VLLM_RINGBUFFER_WARNING_INTERVAL)
229+
n_warning += 1
230+
220231
continue
221232
# found a block that is not read by this reader
222233
# let caller read from the buffer
@@ -226,6 +237,8 @@ def acquire_read(self):
226237
# caller has read from the buffer
227238
# set the read flag
228239
metadata_buffer[self.reader_rank + 1] = 1
240+
self.current_idx = (self.current_idx +
241+
1) % self.buffer.max_chunks
229242
break
230243

231244
def enqueue(self, obj):

0 commit comments

Comments
 (0)