@@ -43,7 +43,7 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
4343
4444 self .buffer_size = 0
4545 self .buffer_size_threshold = buffer_size_thresh
46- self .buffer_lock = threading .Lock ()
46+ self .buffer_cv = threading .Condition ()
4747 self .signal_pipe = signal_pipe
4848 self .data_pipe = data_pipe
4949 self .request_handling_thread : Optional [threading .Thread ] = None
@@ -116,11 +116,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
116116 hidden = hidden .clone ()
117117
118118 buffer_item = [input_tokens , roi , key , value , hidden ]
119+ data_size = sum ([self ._get_element_size (data ) for data in buffer_item ])
119120
120- with self .buffer_lock :
121- for data in buffer_item :
122- self .buffer_size += self ._get_element_size (data )
121+ with self .buffer_cv :
122+ if self .buffer_size + data_size > self .buffer_size_threshold :
123+ # log outside the while loop to avoid this message being logged
124+ # repeatedly.
125+ logger .debug ("KV transfer buffer is full. Handling..." )
126+ while self .buffer_size + data_size > self .buffer_size_threshold :
127+ self .buffer_cv .wait ()
128+
129+ self .buffer_size += data_size
123130 self .buffer .append (buffer_item )
131+ self .buffer_cv .notify ()
124132
125133 def _is_end_signal (self , signal ):
126134 return signal is None
@@ -143,35 +151,29 @@ def drop_select_handler(self):
143151 roi = (roi > 0.5 )
144152 tokens_roi_recver = [input_tokens , roi ]
145153
146- matched_length = 0
147-
148- # perform input tokens and roi matching
149- # FIXME: this matching is O(n), ideally it should be O(1)
150- # but this buffer size won't (and shouldn't) be too large so
151- # the fix is not urgent.
152- with self .buffer_lock :
153-
154+ def is_buffer_available (
155+ tokens_roi_recver : List [torch .Tensor ], ) -> bool :
156+ # perform input tokens and roi matching
157+ # FIXME: this matching is O(n), ideally it should be O(1)
158+ # but this buffer size won't (and shouldn't) be too large so
159+ # the fix is not urgent.
154160 for _ in range (len (self .buffer )):
155-
156- temp_length = self ._matches (self .buffer [0 ],
157- tokens_roi_recver )
158- if temp_length > 0 :
159- matched_length = temp_length
160- break
161+ if self ._matches (self .buffer [0 ],
162+ tokens_roi_recver ) > 0 :
163+ return True
161164 # rotate the element we just accessed to the end
162165 self .buffer .rotate (- 1 )
163-
164- if matched_length > 0 :
165- # need to clone the tensor
166- # in case the tensor is freed before sending finishes
167- matched_item = self .buffer .popleft ()
168- for tensor in matched_item :
169- self ._send_tensor_and_dec_size (tensor )
170-
171- else :
172- # no match, just send None
173- for _ in range (5 ):
174- self .data_pipe .send_tensor (None )
166+ return False
167+
168+ with self .buffer_cv :
169+ while not is_buffer_available (tokens_roi_recver ):
170+ self .buffer_cv .wait ()
171+ # need to clone the tensor
172+ # in case the tensor is freed before sending finishes
173+ matched_item = self .buffer .popleft ()
174+ for tensor in matched_item :
175+ self ._send_tensor_and_dec_size (tensor )
176+ self .buffer_cv .notify ()
175177
176178 except RuntimeError as e :
177179 if 'Connection closed by peer' not in str (e ):
@@ -215,13 +217,6 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
215217 key : torch .Tensor , value : torch .Tensor ,
216218 hidden : torch .Tensor ) -> None :
217219
218- if self .buffer_size > self .buffer_size_threshold :
219- # log outside the while loop to avoid this message being logged
220- # repeatedly.
221- logger .debug ("KV transfer buffer is full. Handling..." )
222- while self .buffer_size > self .buffer_size_threshold :
223- self .full_handler ()
224-
225220 self ._add_to_buffer (input_tokens , roi , key , value , hidden )
226221
227222 # when calling the insert, the current process is a sender
0 commit comments