Commit 4779f98

Fix disagg hang caused by the prefill and decode communication issues
Signed-off-by: Lu Fang <lufang@fb.com>
1 parent 5d98d56

1 file changed: +32 -37 lines

vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py

Lines changed: 32 additions & 37 deletions
@@ -43,7 +43,7 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
 
         self.buffer_size = 0
         self.buffer_size_threshold = buffer_size_thresh
-        self.buffer_lock = threading.Lock()
+        self.buffer_cv = threading.Condition()
         self.signal_pipe = signal_pipe
         self.data_pipe = data_pipe
         self.request_handling_thread: Optional[threading.Thread] = None
@@ -116,11 +116,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
             hidden = hidden.clone()
 
         buffer_item = [input_tokens, roi, key, value, hidden]
+        data_size = sum([self._get_element_size(data) for data in buffer_item])
 
-        with self.buffer_lock:
-            for data in buffer_item:
-                self.buffer_size += self._get_element_size(data)
+        with self.buffer_cv:
+            if self.buffer_size + data_size > self.buffer_size_threshold:
+                # log outside the while loop to avoid this message being logged
+                # repeatedly.
+                logger.debug("KV transfer buffer is full. Handling...")
+                while self.buffer_size + data_size > self.buffer_size_threshold:
+                    self.buffer_cv.wait()
+
+            self.buffer_size += data_size
             self.buffer.append(buffer_item)
+            self.buffer_cv.notify()
 
     def _is_end_signal(self, signal):
         return signal is None
@@ -143,35 +151,29 @@ def drop_select_handler(self):
                 roi = (roi > 0.5)
                 tokens_roi_recver = [input_tokens, roi]
 
-                matched_length = 0
-
-                # perform input tokens and roi matching
-                # FIXME: this matching is O(n), ideally it should be O(1)
-                # but this buffer size won't (and shouldn't) be too large so
-                # the fix is not urgent.
-                with self.buffer_lock:
-
+                def is_buffer_available(
+                    tokens_roi_recver: List[torch.Tensor], ) -> bool:
+                    # perform input tokens and roi matching
+                    # FIXME: this matching is O(n), ideally it should be O(1)
+                    # but this buffer size won't (and shouldn't) be too large so
+                    # the fix is not urgent.
                     for _ in range(len(self.buffer)):
-
-                        temp_length = self._matches(self.buffer[0],
-                                                    tokens_roi_recver)
-                        if temp_length > 0:
-                            matched_length = temp_length
-                            break
+                        if self._matches(self.buffer[0],
+                                         tokens_roi_recver) > 0:
+                            return True
                         # rotate the element we just accessed to the end
                         self.buffer.rotate(-1)
-
-                    if matched_length > 0:
-                        # need to clone the tensor
-                        # in case the tensor is freed before sending finishes
-                        matched_item = self.buffer.popleft()
-                        for tensor in matched_item:
-                            self._send_tensor_and_dec_size(tensor)
-
-                    else:
-                        # no match, just send None
-                        for _ in range(5):
-                            self.data_pipe.send_tensor(None)
+                    return False
+
+                with self.buffer_cv:
+                    while not is_buffer_available(tokens_roi_recver):
+                        self.buffer_cv.wait()
+                    # need to clone the tensor
+                    # in case the tensor is freed before sending finishes
+                    matched_item = self.buffer.popleft()
+                    for tensor in matched_item:
+                        self._send_tensor_and_dec_size(tensor)
+                    self.buffer_cv.notify()
 
         except RuntimeError as e:
             if 'Connection closed by peer' not in str(e):
@@ -215,13 +217,6 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
                key: torch.Tensor, value: torch.Tensor,
                hidden: torch.Tensor) -> None:
 
-        if self.buffer_size > self.buffer_size_threshold:
-            # log outside the while loop to avoid this message being logged
-            # repeatedly.
-            logger.debug("KV transfer buffer is full. Handling...")
-            while self.buffer_size > self.buffer_size_threshold:
-                self.full_handler()
-
         self._add_to_buffer(input_tokens, roi, key, value, hidden)
 
         # when calling the insert, the current process is a sender
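The fix replaces the old threading.Lock plus busy-polling (the full_handler() loop in insert, and the "no match, just send None" fallback in drop_select_handler) with a single threading.Condition: the producer blocks until the buffer has room for the new item, the consumer blocks until a matching item is actually buffered, and each side calls notify() after mutating the buffer so the other side wakes up. Below is a minimal, standalone sketch of that producer/consumer pattern, for illustration only; BoundedBuffer, put, take_first_match, and size_threshold are hypothetical names, not the vLLM API in simple_buffer.py.

# Minimal, standalone sketch of the condition-variable pattern adopted above.
# BoundedBuffer, put() and take_first_match() are illustrative names only,
# not the vLLM API; the real logic lives in simple_buffer.py.
import threading
from collections import deque
from typing import Callable, Deque, List


class BoundedBuffer:
    def __init__(self, size_threshold: int) -> None:
        self.buffer: Deque[List[int]] = deque()
        self.buffer_size = 0
        self.size_threshold = size_threshold
        # A single Condition guards both "there is room" (producer side)
        # and "a matching item exists" (consumer side).
        self.cv = threading.Condition()

    def put(self, item: List[int]) -> None:
        data_size = len(item)
        with self.cv:
            # Producer blocks instead of busy-polling until the item fits.
            while self.buffer_size + data_size > self.size_threshold:
                self.cv.wait()
            self.buffer_size += data_size
            self.buffer.append(item)
            self.cv.notify()  # wake a consumer waiting for data

    def take_first_match(
            self, matches: Callable[[List[int]], bool]) -> List[int]:
        with self.cv:
            # Consumer blocks until some buffered item satisfies the predicate.
            while not any(matches(item) for item in self.buffer):
                self.cv.wait()
            # Rotate the matching item to the front, then pop it.
            while not matches(self.buffer[0]):
                self.buffer.rotate(-1)
            item = self.buffer.popleft()
            self.buffer_size -= len(item)
            self.cv.notify()  # wake a producer waiting for space
            return item

Both waits share one condition variable, mirroring self.buffer_cv in the diff above: the notify() after an append targets a consumer waiting for a match, and the notify() after a pop targets a producer waiting for space.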
