import pickle
import time
from contextlib import contextmanager
from multiprocessing import shared_memory
from typing import Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger

VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

logger = init_logger(__name__)


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
| 27 | + """ |
| 28 | + A shared memory ring buffer implementation for broadcast communication. |
| 29 | + Essentially, it is a queue where only one will `enqueue` and multiple |
| 30 | + will `dequeue`. The max size of each item, together with the max number |
| 31 | + of items that can be stored in the buffer are known in advance. |
| 32 | + In this case, we don't need to synchronize the access to |
| 33 | + the buffer. |
| 34 | + |
| 35 | + Buffer memory layout: |
| 36 | + data metadata |
| 37 | + | | |
| 38 | + | (current_idx) | (current_idx) |
| 39 | + v v |
| 40 | + +-------------------------------+----------------------------------------+ |
| 41 | + | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata | |
| 42 | + +-------------------------------+----------------------------------------+ |
| 43 | + | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes | |
| 44 | +
|
| 45 | + metadata memory layout: each byte is a flag, the first byte is the written |
| 46 | + flag, and the rest are reader flags. The flags are set to 0 by default. |
| 47 | + +--------------+--------------+--------------+-----+--------------+ |
| 48 | + | written_flag | reader0_flag | reader1_flag | ... | readerN_flag | |
| 49 | + +--------------+--------------+--------------+-----+--------------+ |
| 50 | +
|
| 51 | + During creation, `name` is None and the buffer is created. We can pass the |
| 52 | + created object to other processes by pickling it. The other processes will |
| 53 | + get the name of the shared memory and open it, so that they can access the |
| 54 | + same shared memory buffer. |
| 55 | + """# noqa |
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix for https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                self.shared_memory = shared_memory.SharedMemory(name=name)
            assert self.shared_memory.size == self.total_bytes_of_buffer
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
                assert torch.all(tensor == 0)

    def __reduce__(self):
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        self.shared_memory.close()
        if self.is_creator:
            self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf
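

# The helpers below are an illustrative sketch added for this writeup; they are
# not part of the original module and their names (_example_create_buffer,
# _example_share_buffer, _example_attach_buffer) are hypothetical. They show
# the intended lifecycle: the creator allocates the segment, ships a pickled
# handle to other processes, and those processes reopen the same segment by
# name (see __reduce__ above).


def _example_create_buffer() -> ShmRingBuffer:
    # The creator allocates the shared memory and zeroes the metadata section.
    # It must keep this object alive, since __del__ unlinks the segment.
    return ShmRingBuffer(n_reader=2, max_chunk_bytes=1 << 20, max_chunks=10)


def _example_share_buffer(ring: ShmRingBuffer) -> bytes:
    # Pickling captures only the constructor arguments plus the shared memory
    # name, so the handle is tiny and can be sent over a pipe or a socket.
    return pickle.dumps(ring)


def _example_attach_buffer(handle: bytes) -> ShmRingBuffer:
    # In a reader process, unpickling calls __init__ with a non-None `name`,
    # which opens the existing segment instead of creating a new one.
    return pickle.loads(handle)

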
class ShmRingBufferIO:

    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
        self.buffer = buffer
        self.reader_rank = reader_rank
        self._is_writer = self.reader_rank == -1
        self._is_reader = not self._is_writer
        if self._is_reader:
            assert 0 <= self.reader_rank < buffer.n_reader, \
                (f"Invalid reader rank {self.reader_rank} for buffer"
                 f" created with {buffer.n_reader} readers")
        self.current_idx = 0

    @contextmanager
    def acquire_write(self):
        assert self._is_writer, "Only writers can acquire write"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # try to write to the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no empty block found
                        if time.time(
                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                            logger.warning(
                                "No available block found in %s seconds.",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # mark the block as written
                metadata_buffer[0] = 1
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                break

    @contextmanager
    def acquire_read(self):
        assert self._is_reader, "Only readers can acquire read"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # try to read the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no block found
                        if time.time(
                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                            logger.warning(
                                "No available block found in %s seconds.",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.reader_rank + 1] = 1
                break

    def enqueue(self, obj):
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if len(serialized_obj) > self.buffer.max_chunk_bytes:
            raise RuntimeError(
                f"{len(serialized_obj)=} is larger than the allowed value "
                f"{self.buffer.max_chunk_bytes}. "
                "Please increase the max_chunk_bytes parameter.")
        with self.acquire_write() as buf:
            buf[:len(serialized_obj)] = serialized_obj

    def dequeue(self):
        assert self._is_reader, "Only readers can dequeue"
        with self.acquire_read() as buf:
            # no need to know the size of the serialized object:
            # the pickle format itself contains the size information
            # see https://docs.python.org/3/library/pickle.html
            obj = pickle.loads(buf)
        return obj

    def broadcast_object(self, obj=None):
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "ShmRingBufferIO":
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        ranks_inside_group = list(range(group_world_size))
        global_ranks = dist.get_process_group_ranks(pg)
        n_reader = group_world_size - 1
        buffer: ShmRingBuffer
        if group_rank == writer_rank:
            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
            dist.broadcast_object_list([buffer],
                                       src=global_ranks[writer_rank])
            dist.barrier(pg)
            return ShmRingBufferIO(buffer, -1)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
            dist.barrier(pg)
            buffer = recv[0]  # type: ignore
            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
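

# The following is an illustrative sketch added for this writeup, not part of
# the original module; the function name `_example_broadcast` and its
# orchestration are hypothetical. It assumes torch.distributed has already
# been initialized (e.g. with the gloo backend) and that all ranks in `pg`
# live on the same node, since the queue is backed by local shared memory.


def _example_broadcast(pg: ProcessGroup, n_steps: int):
    # Every rank in the group calls this collectively. With the default
    # writer_rank=0, group rank 0 becomes the writer (reader_rank == -1)
    # and all other ranks become readers.
    comm = ShmRingBufferIO.create_from_process_group(pg,
                                                     max_chunk_bytes=1 << 20,
                                                     max_chunks=10)
    results = []
    for step in range(n_steps):
        if dist.get_rank(pg) == 0:
            # The writer pickles the object into a free chunk and flips the
            # written flag; readers spin in dequeue until it appears.
            obj = comm.broadcast_object({"step": step})
        else:
            obj = comm.broadcast_object()
        results.append(obj)
    return results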