1 | 1 | import enum |
2 | 2 | from dataclasses import dataclass |
3 | | -from typing import Any, List, Tuple, Union |
| 3 | +from typing import Any, List |
4 | 4 |
5 | | -import torch |
6 | | -from ordered_set import OrderedSet |
7 | | - |
8 | | -from colossalai.inference.flash_decoding_utils import FDIntermTensors |
9 | 5 | from colossalai.logging import get_dist_logger |
10 | 6 |
11 | 7 | logger = get_dist_logger(__name__) |
@@ -170,242 +166,6 @@ def __repr__(self) -> str: |
170 | 166 | ) |
171 | 167 |
172 | 168 |
173 | | -@dataclass |
174 | | -class BatchInfo: |
175 | | - """ |
176 | | - Information to be passed and used for a batch of sequences. |
177 | | - """ |
178 | | - |
179 | | - max_batch_size: int |
180 | | - kv_max_split_num: int |
181 | | - num_heads: int |
182 | | - head_dim: int |
183 | | - sequences_set: OrderedSet[Sequence] = None |
184 | | - is_prompts: bool = True |
185 | | - device: torch.device = None |
186 | | - dtype: torch.dtype = None |
187 | | - fd_inter_tensor: FDIntermTensors = None |
188 | | - |
189 | | - def __post_init__(self): |
190 | | - if self.device is None: |
191 | | - self.device = torch.cuda.current_device() |
192 | | - if self.sequences_set is None: |
193 | | - self.sequences_set = OrderedSet() |
194 | | - if self.fd_inter_tensor is None: |
195 | | - self.fd_inter_tensor = FDIntermTensors() |
196 | | - |
197 | | - def init_fd_tensors(self): |
198 | | - if not self.fd_inter_tensor.is_initialized: |
199 | | - self.fd_inter_tensor.initialize( |
200 | | - max_batch_size=self.max_batch_size, |
201 | | - num_attn_heads=self.num_heads, |
202 | | - kv_max_split_num=self.kv_max_split_num, |
203 | | - head_dim=self.head_dim, |
204 | | - dtype=self.dtype, |
205 | | - device=self.device, |
206 | | - ) |
207 | | - |
208 | | - def get_block_table_tensor(self) -> torch.Tensor: |
209 | | - tensor_list = [] |
210 | | - block_table = None |
211 | | - |
212 | | - assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first." |
213 | | - |
214 | | - for seq in self.sequences_set: |
215 | | - block_table = seq.block_table |
216 | | - assert ( |
217 | | - block_table is not None |
218 | | - ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table." |
219 | | - tensor_list.append(seq.block_table) |
220 | | - |
221 | | - block_table = torch.stack(tensor_list) |
222 | | - return block_table |
223 | | - |
224 | | - def clear_batch(self) -> None: |
225 | | - """ |
226 | | - Clear sequence set and block table if we need to abort this batch. |
227 | | - Prefill: clear the sequence set; the sequences are moved to the running batch externally. |
228 | | - Decoding: mark unfinished sequences as aborted. |
229 | | - """ |
230 | | - if self.is_prompts: |
231 | | - self.sequences_set.clear() |
232 | | - else: |
233 | | - for seq in self.sequences_set: |
234 | | - seq.mark_aborted() |
235 | | - if seq.check_finish(): |
236 | | - seq.mark_finished() |
237 | | - |
238 | | - self.sequences_set.clear() |
239 | | - |
240 | | - def filter_batch(self) -> List["Sequence"]: |
241 | | - """ |
242 | | - Remove completed sequences from the batch. |
243 | | -
244 | | - Returns: |
245 | | - List["Sequence"]: List of finished sequences. |
246 | | - """ |
247 | | - finish_seqs = [] |
248 | | - for seq in self.sequences_set: |
249 | | - if seq.check_finish(): |
250 | | - finish_seqs.append(seq) |
251 | | - for finish_seq in finish_seqs: |
252 | | - self.sequences_set.discard(finish_seq) |
253 | | - return finish_seqs |
254 | | - |
255 | | - def abort_seq(self, seq: "Sequence") -> "Sequence": |
256 | | - """ |
257 | | - Remove sequence from the batch. |
258 | | - """ |
259 | | - if not seq.check_finish(): |
260 | | - seq.status = RequestStatus.ABORTED |
261 | | - self.sequences_set.discard(seq) |
262 | | - return seq |
263 | | - |
264 | | - def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None: |
265 | | - """ |
266 | | - Add new sequences to the batch. |
267 | | -
268 | | - Args: |
269 | | - seqs (List["Sequence"]): The list of new sequences. |
270 | | - """ |
271 | | - # convert a single sequence to a list |
272 | | - if isinstance(seqs, Sequence): |
273 | | - seqs = [seqs] |
274 | | - |
275 | | - for seq in seqs: |
276 | | - if seq in self.sequences_set: |
277 | | - logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") |
278 | | - continue |
279 | | - self.sequences_set.add(seq) |
280 | | - |
281 | | - def del_seq(self, seq: Sequence) -> None: |
282 | | - """ |
283 | | - Delete a sequence from the batch. |
284 | | - """ |
285 | | - self.sequences_set.discard(seq) |
286 | | - |
287 | | - @property |
288 | | - def is_empty(self) -> bool: |
289 | | - """ |
290 | | - Check whether sequences_set is empty. |
291 | | - """ |
292 | | - return not self.sequences_set |
293 | | - |
294 | | - def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None: |
295 | | - """ |
296 | | - Add an output token for each sequence in the batch. |
297 | | -
298 | | - Args: |
299 | | - tokens (Union[List[int], List[List[int]], torch.Tensor]): A batch of tokens, one entry per sequence. |
300 | | - """ |
301 | | - |
302 | | - if isinstance(tokens, torch.Tensor): |
303 | | - tokens = tokens.tolist() |
304 | | - |
305 | | - assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size." |
306 | | - |
307 | | - for seq, token in zip(self.sequences_set, tokens): |
308 | | - if not isinstance(token, list): |
309 | | - if not isinstance(token, int): |
310 | | - raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.") |
311 | | - token = [token] |
312 | | - seq.output_token_id += token |
313 | | - seq.check_finish() |
314 | | - |
315 | | - def get_batch_size(self) -> int: |
316 | | - """ |
317 | | - Get batch_size of this batch |
318 | | - """ |
319 | | - return len(self.sequences_set) |
320 | | - |
321 | | - def get_batch_inputs(self) -> torch.LongTensor: |
322 | | - """ |
323 | | - Get batch inputs for forward inference computation. |
324 | | - """ |
325 | | - |
326 | | - input_list = [] |
327 | | - |
328 | | - assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first." |
329 | | - |
330 | | - for seq in self.sequences_set: |
331 | | - if self.is_prompts: |
332 | | - if seq.output_len > 0: |
333 | | - input_list.append(seq.input_token_id + seq.output_token_id) |
334 | | - else: |
335 | | - input_list.append(seq.input_token_id) |
336 | | - else: |
337 | | - input_list.append([seq.output_token_id[-1]]) |
338 | | - |
339 | | - max_seq_len = max(len(sub_list) for sub_list in input_list) |
340 | | - |
341 | | - # We assume all sequences in the batch share the same pad_token_id for now. |
342 | | - return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int) |
343 | | - |
344 | | - def get_1D_inputs(self) -> torch.Tensor: |
345 | | - """ |
346 | | - Flatten the input tokens into a single 1D tensor. |
347 | | - """ |
348 | | - input_list = [] |
349 | | - |
350 | | - assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first." |
351 | | - |
352 | | - for seq in self.sequences_set: |
353 | | - if self.is_prompts: |
354 | | - input_list.extend(seq.input_token_id) |
355 | | - else: |
356 | | - input_list.append(seq.output_token_id[-1]) |
357 | | - |
358 | | - return torch.tensor(input_list, dtype=torch.long, device=self.device) |
359 | | - |
360 | | - def get_sequence_lengths(self): |
361 | | - """ |
362 | | - Get the sentence_len (current total length) of each sequence in this batch. |
363 | | - """ |
364 | | - len_list = [] |
365 | | - |
366 | | - assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first." |
367 | | - |
368 | | - for seq in self.sequences_set: |
369 | | - len_list.append(seq.sentence_len) |
370 | | - |
371 | | - return torch.tensor(len_list, dtype=torch.int, device=self.device) |
372 | | - |
373 | | - def get_attn_mask(self) -> torch.Tensor: |
374 | | - """ |
375 | | - Generate and return attention mask. |
376 | | - """ |
377 | | - assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first." |
378 | | - |
379 | | - past_values = [] |
380 | | - # We assume all sequences in the batch share the same pad_token_id for now. |
381 | | - padding_id = self.sequences_set[0].pad_token_id |
382 | | - |
383 | | - for seq in self.sequences_set: |
384 | | - past_values.append(seq.input_token_id + seq.output_token_id) |
385 | | - |
386 | | - max_seq_len = max(len(sub_list) for sub_list in past_values) |
387 | | - attn_mask = _make_tensor_with_pad( |
388 | | - past_values, max_seq_len, padding_id, dtype=torch.int, device=self.device |
389 | | - ) |
390 | | - |
391 | | - return attn_mask.ne(padding_id).long() |
392 | | - |
393 | | - def __repr__(self) -> str: |
394 | | - return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})" |
395 | | - |
396 | | - |
397 | 169 | def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: |
398 | 170 | assert len(x) <= max_len |
399 | 171 | return [pad] * (max_len - len(x)) + x |
400 | | - |
401 | | - |
402 | | -def _make_tensor_with_pad( |
403 | | - x: Union[List[List[int]], List[int]], |
404 | | - max_len: int, |
405 | | - pad: int, |
406 | | - dtype: torch.dtype, |
407 | | - device: Union[str, torch.device] = "cuda", |
408 | | - pin_memory: bool = False, |
409 | | -): |
410 | | - padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] |
411 | | - return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu") |
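
For context, here is a minimal standalone sketch (not part of the commit) of the padding scheme these helpers implement: the retained `_pad_to_max` left-pads each token list to the batch maximum, and the removed `_make_tensor_with_pad` stacked the padded lists into a tensor, from which a mask like the one in the removed `get_attn_mask` can be derived. The token ids and pad value below are illustrative only.

```python
import torch


def _pad_to_max(x, max_len, pad):
    # Left-pad `x` with `pad` until it reaches `max_len` (helper kept by this commit).
    assert len(x) <= max_len
    return [pad] * (max_len - len(x)) + x


# Hypothetical ragged batch of token ids, with pad_token_id assumed to be 0.
batch = [[5, 6, 7], [8, 9]]
max_len = max(len(seq) for seq in batch)
inputs = torch.tensor([_pad_to_max(seq, max_len, pad=0) for seq in batch], dtype=torch.int)
# inputs: tensor([[5, 6, 7],
#                 [0, 8, 9]], dtype=torch.int32)

# Attention mask in the style of the removed get_attn_mask(): 1 for real tokens, 0 for padding.
attn_mask = inputs.ne(0).long()
# attn_mask: tensor([[1, 1, 1],
#                    [0, 1, 1]])
```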
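
The removed `BatchInfo` kept its sequences in an `ordered_set.OrderedSet`, which explains two patterns in the deleted code: iteration and indexing (`self.sequences_set[0]`) are deterministic because insertion order is preserved, and removal goes through `discard` because it never raises on a missing element. A small illustration of those semantics, using placeholder request ids:

```python
from ordered_set import OrderedSet

seqs = OrderedSet()
seqs.add("req-0")
seqs.add("req-1")
seqs.add("req-0")        # duplicate adds are ignored, insertion order is kept
print(list(seqs))        # ['req-0', 'req-1']
print(seqs[0])           # 'req-0' -- indexing works, unlike the builtin set
seqs.discard("missing")  # discard never raises, even if the element is absent
seqs.discard("req-0")
print(list(seqs))        # ['req-1']
```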