
Commit c08f3c5

youkaichao authored and comaniac committed
[Core][Distributed] add shm broadcast (vllm-project#5399)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
1 parent 84cb37a commit c08f3c5

File tree

5 files changed (+384, -10 lines)


.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 1 deletion
@@ -28,9 +28,11 @@ steps:
 
 - label: Distributed Comm Ops Test
   #mirror_hardwares: [amd]
-  command: pytest -v -s distributed/test_comm_ops.py
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
+  commands:
+  - pytest -v -s distributed/test_comm_ops.py
+  - pytest -v -s distributed/test_shm_broadcast.py
 
 - label: Distributed Tests (2 GPUs)
   mirror_hardwares: [amd]
tests/distributed/test_shm_broadcast.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
import multiprocessing
import random
import time

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import (
    ShmRingBuffer, ShmRingBufferIO)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for p in processes:
        assert p.exitcode == 0


def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    writer_rank = 2
    broadcaster = ShmRingBufferIO.create_from_process_group(
        dist.group.WORLD, 1024, 2, writer_rank)
    if dist.get_rank() == writer_rank:
        time.sleep(random.random())
        broadcaster.broadcast_object(0)
        time.sleep(random.random())
        broadcaster.broadcast_object({})
        time.sleep(random.random())
        broadcaster.broadcast_object([])
    else:
        time.sleep(random.random())
        a = broadcaster.broadcast_object(None)
        time.sleep(random.random())
        b = broadcaster.broadcast_object(None)
        time.sleep(random.random())
        c = broadcaster.broadcast_object(None)
        assert a == 0
        assert b == {}
        assert c == []
    dist.barrier()


def test_shm_broadcast():
    distributed_run(worker_fn, 4)


def test_single_process():
    buffer = ShmRingBuffer(1, 1024, 4)
    reader = ShmRingBufferIO(buffer, reader_rank=0)
    writer = ShmRingBufferIO(buffer, reader_rank=-1)
    writer.enqueue([0])
    writer.enqueue([1])
    assert reader.dequeue() == [0]
    assert reader.dequeue() == [1]
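
As a quick reference for test_single_process above, here is a small worked example (not part of the commit) of the sizes implied by ShmRingBuffer(1, 1024, 4), following the layout defined in shm_broadcast.py below:

# illustrative arithmetic only; mirrors the attributes set in ShmRingBuffer.__init__
n_reader = 1                # the single-process test uses one reader
max_chunk_bytes = 1024
max_chunks = 4

metadata_size = 1 + n_reader                    # written flag + 1 reader flag = 2 bytes
total_bytes_of_buffer = (max_chunk_bytes +
                         metadata_size) * max_chunks   # (1024 + 2) * 4 = 4104 bytes
data_offset = 0                                 # chunks start at the beginning of the segment
metadata_offset = max_chunk_bytes * max_chunks  # metadata starts at 1024 * 4 = 4096

assert total_bytes_of_buffer == 4104
assert metadata_offset == 4096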
vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
import pickle
import time
from contextlib import contextmanager
from multiprocessing import shared_memory
from typing import Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger

VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

logger = init_logger(__name__)


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
        the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the
        written flag, and the rest are reader flags. The flags are set to 0
        by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
        """# noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                self.shared_memory = shared_memory.SharedMemory(name=name)
            assert self.shared_memory.size == self.total_bytes_of_buffer
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
                assert torch.all(tensor == 0)

    def __reduce__(self):
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        self.shared_memory.close()
        if self.is_creator:
            self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf


class ShmRingBufferIO:

    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
        self.buffer = buffer
        self.reader_rank = reader_rank
        self._is_writer = self.reader_rank == -1
        self._is_reader = not self._is_writer
        if self._is_reader:
            assert 0 <= self.reader_rank < buffer.n_reader, \
                (f"Invalid reader rank {self.reader_rank} for buffer"
                 f" created with {buffer.n_reader} readers")
        self.current_idx = 0

    @contextmanager
    def acquire_write(self):
        assert self._is_writer, "Only writers can acquire write"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # try to write to the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no empty block found
                        if time.time(
                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                            logger.warning(
                                "No available block found in %s second. ",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # mark the block as written
                metadata_buffer[0] = 1
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                break

    @contextmanager
    def acquire_read(self):
        assert self._is_reader, "Only readers can acquire read"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # try to read the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no block found
                        if time.time(
                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                            logger.warning(
                                "No available block found in %s second. ",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.reader_rank + 1] = 1
                break

    def enqueue(self, obj):
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if len(serialized_obj) > self.buffer.max_chunk_bytes:
            raise RuntimeError(
                f"{len(serialized_obj)=} larger than the allowed value "
                f"{self.buffer.max_chunk_bytes},"
                "Please increase the max_chunk_bytes parameter.")
        with self.acquire_write() as buf:
            buf[:len(serialized_obj)] = serialized_obj

    def dequeue(self):
        assert self._is_reader, "Only readers can dequeue"
        with self.acquire_read() as buf:
            # no need to know the size of serialized object
            # pickle format itself contains the size information internally
            # see https://docs.python.org/3/library/pickle.html
            obj = pickle.loads(buf)
        return obj

    def broadcast_object(self, obj=None):
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "ShmRingBufferIO":
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        ranks_inside_group = list(range(group_world_size))
        global_ranks = dist.get_process_group_ranks(pg)
        n_reader = group_world_size - 1
        buffer: ShmRingBuffer
        if group_rank == writer_rank:
            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
            dist.broadcast_object_list([buffer], src=global_ranks[writer_rank])
            dist.barrier(pg)
            return ShmRingBufferIO(buffer, -1)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
            dist.barrier(pg)
            buffer = recv[0]  # type: ignore
            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
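
For context, a minimal usage sketch (not part of the commit) of how ShmRingBufferIO is meant to be driven from a process group, mirroring worker_fn in the test above; the writer rank, sizes, and payload here are illustrative assumptions:

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import ShmRingBufferIO


def broadcast_via_shm():
    # assumes torch.distributed is already initialized and all ranks in the
    # group run on the same node, since the queue lives in shared memory
    writer_rank = 0  # illustrative; the test above uses rank 2
    broadcaster = ShmRingBufferIO.create_from_process_group(
        dist.group.WORLD,      # process group to broadcast within
        max_chunk_bytes=1024,  # upper bound on one pickled message
        max_chunks=2,          # ring slots available before the writer spins
        writer_rank=writer_rank)
    if dist.get_rank() == writer_rank:
        # the writer passes the payload; it is pickled into a free chunk
        return broadcaster.broadcast_object({"step": 1})
    # readers pass None (the default) and receive the writer's payload
    return broadcaster.broadcast_object(None)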
