@@ -29,6 +29,9 @@ class RequestStatus(enum.Enum):
     COMPLETED = enum.auto()
     LENGTH_CAPPED = enum.auto()

+    # recycle status
+    RECYCLED = enum.auto()
+
     @staticmethod
     def is_finished(status: "RequestStatus") -> bool:
         return status in [
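For orientation, here is a minimal, self-contained sketch of the status enum as this hunk leaves it. Only the members visible somewhere in this diff are shown, and the body of `is_finished` is an assumption based on them; the real class may define more members and a larger finished set:

import enum


class RequestStatus(enum.Enum):
    # members visible in this diff; the real class may define more
    WAITING = enum.auto()
    RUNNING = enum.auto()
    ABORTED = enum.auto()
    COMPLETED = enum.auto()
    LENGTH_CAPPED = enum.auto()
    # recycle status
    RECYCLED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        # RECYCLED is deliberately not a finished state: a recycled
        # sequence returns to the scheduler and can run again.
        return status in [RequestStatus.COMPLETED, RequestStatus.LENGTH_CAPPED]


assert not RequestStatus.is_finished(RequestStatus.RECYCLED)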
@@ -119,7 +122,9 @@ def mark_running(self) -> None:
         """
         Set status for prefill reqs.
         """
-        assert self.status == RequestStatus.WAITING, "Sequence is not in WAITTING STATUS"
+        assert (
+            self.status in (RequestStatus.WAITING, RequestStatus.RECYCLED)
+        ), "Sequence is not in WAITING/RECYCLED STATUS"
         self.status = RequestStatus.RUNNING

     def mark_finished(self) -> None:
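The assert rewritten above deserves a note: as submitted, the condition read `self.status == RequestStatus.WAITING or RequestStatus.RECYCLED`, which parses as `(self.status == RequestStatus.WAITING) or RequestStatus.RECYCLED`. An enum member is always truthy, so that check could never fail. A short demonstration of the pitfall, using the `RequestStatus` enum above:

status = RequestStatus.ABORTED

# Buggy form: always True, regardless of status.
print(bool(status == RequestStatus.WAITING or RequestStatus.RECYCLED))  # True

# Correct membership test: False for ABORTED, as intended.
print(status in (RequestStatus.WAITING, RequestStatus.RECYCLED))        # False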
@@ -139,10 +144,10 @@ def recycle(self) -> None:
         Recycle a running sequence back to the waiting list.
         """
         assert (
-            not self.status.is_finished and not self.status == RequestStatus.ABORTED
+            not self.check_finish() and not self.status == RequestStatus.ABORTED
         ), "The running sequence \
        is already done but it is still in the running list"
-        self.status = RequestStatus.WAITING
+        self.status = RequestStatus.RECYCLED

     def __repr__(self) -> str:
         return (
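Taken together, the hunks above give a preempted sequence the lifecycle WAITING -> RUNNING -> RECYCLED -> RUNNING. A toy stand-in (not the real `Sequence` class, whose constructor and bookkeeping are elided here) that exercises just these transitions:

class _ToySequence:
    """Stand-in with only the status transitions from this diff."""

    def __init__(self) -> None:
        self.status = RequestStatus.WAITING

    def mark_running(self) -> None:
        assert self.status in (RequestStatus.WAITING, RequestStatus.RECYCLED)
        self.status = RequestStatus.RUNNING

    def recycle(self) -> None:
        assert not RequestStatus.is_finished(self.status)
        self.status = RequestStatus.RECYCLED


seq = _ToySequence()
seq.mark_running()  # WAITING  -> RUNNING  (scheduled for prefill)
seq.recycle()       # RUNNING  -> RECYCLED (preempted; output so far is kept)
seq.mark_running()  # RECYCLED -> RUNNING  (rescheduled later)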
@@ -162,7 +167,7 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """

-    sequences_set: OrderedSet["Sequence"] = None
+    sequences_set: OrderedSet[Sequence] = None
     is_prompts: bool = True
     device: torch.device = None

@@ -207,12 +212,20 @@ def get_block_table_tensor(self) -> None:

     def clear_batch(self) -> None:
         """
-        Clear sequence set and block table.
+        Clear sequence set and block table if we need to abort this batch.
+        Prefill: clear the sequence set and move the sequences to the running batch (external).
+        Decoding: mark unfinished sequences as aborted.
         """
-        for seq in self.sequences_set:
-            if not seq.check_finish():
-                seq.status = RequestStatus.ABORTED
-        self.sequences_set.clear()
+        if self.is_prompts:
+            self.sequences_set.clear()
+            return
+
+        for seq in self.sequences_set:
+            if seq.check_finish():
+                seq.mark_finished()
+            else:
+                seq.mark_aborted()
+        self.sequences_set.clear()

     def fliter_batch(self) -> List["Sequence"]:
         """
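The split in `clear_batch` encodes two meanings of "clearing": after prefill the sequences continue in the running batch, so only the set is emptied, while aborting mid-decoding really abandons the unfinished sequences. A self-contained sketch of that rule with stub objects (not the real `Sequence`/`BatchInfo` classes):

class _StubSeq:
    def __init__(self, done: bool) -> None:
        self.done = done
        self.status = "RUNNING"

    def check_finish(self) -> bool:
        return self.done

    def mark_finished(self) -> None:
        self.status = "FINISHED"

    def mark_aborted(self) -> None:
        self.status = "ABORTED"


def clear_batch(seqs: set, is_prompts: bool) -> None:
    if not is_prompts:
        # decoding: abandoning the batch abandons its sequences
        for s in seqs:
            if s.check_finish():
                s.mark_finished()
            else:
                s.mark_aborted()
    # prefill: the sequences live on in the running batch (external),
    # so only the set itself is emptied
    seqs.clear()


a, b = _StubSeq(done=True), _StubSeq(done=False)
clear_batch({a, b}, is_prompts=False)
print(a.status, b.status)  # FINISHED ABORTED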
@@ -255,6 +268,12 @@ def add_seqs(self, seqs: List["Sequence"]) -> None:
                 continue
             self.sequences_set.add(seq)

+    def del_seq(self, seq: Sequence) -> None:
+        """
+        Delete a sequence from the batch.
+        """
+        self.sequences_set.discard(seq)
+
     @property
     def is_empty(self) -> None:
         """
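`del_seq` uses `discard` rather than `remove`; assuming the `OrderedSet` in use mirrors the built-in `set` API here (ordered-set implementations generally do), deleting a sequence that is no longer in the batch becomes a silent no-op instead of raising:

s = {1, 2, 3}

s.discard(4)  # element absent: no-op, set unchanged
print(s)      # {1, 2, 3}

try:
    s.remove(4)
except KeyError:
    print("remove() raises when the element is missing")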
@@ -297,11 +316,16 @@ def get_batch_inputs(self) -> torch.LongTensor:

         for seq in self.sequences_set:
             if self.is_prompts:
-                input_list.append(seq.input_token_id)
+                if seq.output_len > 0:
+                    input_list.append(seq.input_token_id + seq.output_token_id)
+                else:
+                    input_list.append(seq.input_token_id)
             else:
                 input_list.append([seq.output_token_id[-1]])

-        return torch.tensor(input_list, dtype=torch.long, device=self.device)
+        max_seq_len = max(len(sub_list) for sub_list in input_list)
+
+        return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int, device=self.device)

     def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
         """
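The effect of the padding change in `get_batch_inputs`: a recycled sequence replays its prompt plus whatever it already generated, so a prefill batch can now hold rows of different lengths, which get right-padded to the batch maximum. For example (token ids are made up):

import torch

# batch of two prefill rows: one recycled (3 prompt + 2 generated tokens),
# one fresh prompt
input_list = [[11, 12, 13] + [14, 15], [21, 22]]

max_seq_len = max(len(sub_list) for sub_list in input_list)
padded = [row + [0] * (max_seq_len - len(row)) for row in input_list]

print(torch.tensor(padded, dtype=torch.int))
# tensor([[11, 12, 13, 14, 15],
#         [21, 22,  0,  0,  0]], dtype=torch.int32)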
@@ -340,12 +364,27 @@ def get_attn_mask(self, padding_id: int) -> torch.Tensor:
         for seq in self.sequences_set:
             past_values.append(seq.input_token_id + seq.output_token_id)

-        attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
+        max_seq_len = max(len(sub_list) for sub_list in past_values)
+        attn_mask = _make_tensor_with_pad(past_values, max_seq_len, padding_id, dtype=torch.int, device=self.device)

-        if torch.any(attn_mask == 0):
-            return attn_mask
-        else:
-            return None
+        return attn_mask.ne(padding_id).long()

     def __repr__(self) -> str:
         return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
+
+
+def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
+    assert len(x) <= max_len
+    return x + [pad] * (max_len - len(x))
+
+
+def _make_tensor_with_pad(
+    x: Union[List[List[int]], List[int]],
+    max_len: int,
+    pad: int,
+    dtype: torch.dtype,
+    device: Union[str, torch.device] = "cuda",
+    pin_memory: bool = False,
+):
+    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
+    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")
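A usage note on the two helpers: `get_attn_mask` above pads with `padding_id` and then masks with `.ne(padding_id)`, so padded positions come out 0 and real tokens 1. If the pad value and `padding_id` ever diverged, padding would be counted as real tokens, which is why the pad argument matters:

import torch

padding_id = 0
padded = torch.tensor([[5, 6, 7, 8], [9, 10, padding_id, padding_id]], dtype=torch.int)

attn_mask = padded.ne(padding_id).long()
print(attn_mask)
# tensor([[1, 1, 1, 1],
#         [1, 1, 0, 0]])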