Skip to content

Commit 832ea88

Browse files
authored
[core][distributed] improve shared memory broadcast (#5754)
1 parent 8c00f9c commit 832ea88

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,26 @@ def __init__(self,
4848
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
4949
+--------------+--------------+--------------+-----+--------------+
5050
51+
The state of metadata is as follows:
52+
53+
(case 1) 0???...???: the block is not written yet, cannot read, can write
54+
(case 2) 1000...000: the block is just written, can read, cannot write
55+
(case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
56+
(case 4) 1111...111: the block is written and read by all readers, cannot read, can write
57+
58+
State transition for readers:
59+
60+
When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
61+
Only after the caller finishes reading the block, the reader can mark the block as read.
62+
Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
63+
64+
State transition for writer:
65+
66+
When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
67+
to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
68+
can reset the reader flags to 0, and mark the block as written (from 0 to 1).
69+
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
70+
5171
During creation, `name` is None and the buffer is created. We can pass the
5272
created object to other processes by pickling it. The other processes will
5373
get the name of the shared memory and open it, so that they can access the
@@ -81,10 +101,6 @@ def __init__(self,
81101
lambda *args, **kwargs: None):
82102
self.shared_memory = shared_memory.SharedMemory(name=name)
83103
assert self.shared_memory.size == self.total_bytes_of_buffer
84-
with memoryview(self.shared_memory.buf[self.metadata_offset:]
85-
) as metadata_buffer:
86-
tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
87-
assert torch.all(tensor == 0)
88104

89105
def __reduce__(self):
90106
return (
@@ -163,11 +179,15 @@ def acquire_write(self):
163179
yield buf
164180

165181
# caller has written to the buffer
166-
# mark the block as written
167-
metadata_buffer[0] = 1
182+
# NOTE: order is important here
183+
# first set the read flags to 0
184+
# then set the written flag to 1
185+
# otherwise, the readers may think they already read the block
168186
for i in range(1, self.buffer.n_reader + 1):
169187
# set read flag to 0, meaning it is not read yet
170188
metadata_buffer[i] = 0
189+
# mark the block as written
190+
metadata_buffer[0] = 1
171191
break
172192

173193
@contextmanager
@@ -247,13 +267,15 @@ def create_from_process_group(pg: ProcessGroup,
247267
buffer: ShmRingBuffer
248268
if group_rank == writer_rank:
249269
buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
250-
dist.broadcast_object_list([buffer], src=global_ranks[writer_rank])
251-
dist.barrier(pg)
270+
dist.broadcast_object_list([buffer],
271+
src=global_ranks[writer_rank],
272+
group=pg)
252273
return ShmRingBufferIO(buffer, -1)
253274
else:
254275
recv = [None]
255-
dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
256-
dist.barrier(pg)
276+
dist.broadcast_object_list(recv,
277+
src=global_ranks[writer_rank],
278+
group=pg)
257279
buffer = recv[0] # type: ignore
258280
rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
259281
return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))

0 commit comments

Comments
 (0)