Commit 83b98f9

rm deprecated cls (removed dep)
1 parent d1dadf0 commit 83b98f9

1 file changed (+1, -227 lines)

colossalai/inference/struct.py

Lines changed: 1 addition & 227 deletions
@@ -1,11 +1,9 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Union
 
 import torch
-from ordered_set import OrderedSet
 
-from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
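
Note: the ordered_set import dropped above is presumably what "(removed dep)" in the commit message refers to; with BatchInfo gone (next hunk), nothing in this file uses the third-party ordered_set package any more. As an illustration only (not part of this commit), OrderedSet deduplicates like a set but keeps insertion order and supports indexing, which BatchInfo relied on for sequences_set[0]:

    # Illustration only, not part of this diff. OrderedSet (third-party
    # ordered_set package) combines set semantics with stable insertion
    # order and index access.
    from ordered_set import OrderedSet

    seqs = OrderedSet()
    seqs.add("request-1")
    seqs.add("request-2")
    seqs.add("request-1")            # duplicate: ignored
    assert list(seqs) == ["request-1", "request-2"]
    assert seqs[0] == "request-1"    # index access, as in sequences_set[0]
    seqs.discard("request-2")        # safe removal, as in del_seq/abort_seq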
@@ -170,230 +168,6 @@ def __repr__(self) -> str:
         )
 
 
-@dataclass
-class BatchInfo:
-    """
-    Information to be passed and used for a batch of sequences.
-    """
-
-    max_batch_size: int
-    kv_max_split_num: int
-    num_heads: int
-    head_dim: int
-    sequences_set: OrderedSet[Sequence] = None
-    is_prompts: bool = True
-    device: torch.device = None
-    dtype: torch.dtype = None
-    fd_inter_tensor: FDIntermTensors = None
-
-    def __post_init__(self):
-        if self.device is None:
-            self.device = torch.cuda.current_device()
-        if self.sequences_set is None:
-            self.sequences_set = OrderedSet()
-        if self.fd_inter_tensor is None:
-            self.fd_inter_tensor = FDIntermTensors()
-
-    def init_fd_tensors(self):
-        if not self.fd_inter_tensor.is_initialized:
-            self.fd_inter_tensor.initialize(
-                max_batch_size=self.max_batch_size,
-                num_attn_heads=self.num_heads,
-                kv_max_split_num=self.kv_max_split_num,
-                head_dim=self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            )
-
-    def get_block_table_tensor(self) -> None:
-        tesnor_list = []
-        block_table = None
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            block_table = seq.block_table
-            assert (
-                block_table is not None
-            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
-            tesnor_list.append(seq.block_table)
-
-        block_table = torch.stack(tesnor_list)
-        return block_table
-
-    def clear_batch(self) -> None:
-        """
-        Clear sequence set and block table if we need to abort this batch.
-        Prefill: clear sequence set and move them to running batch(external)
-        Decoding: mark unfinished sequences as aborted.
-        """
-        if self.is_prompts:
-            self.sequences_set.clear()
-        else:
-            for seq in self.sequences_set:
-                seq.mark_aborted()
-                if seq.check_finish():
-                    seq.mark_finished()
-
-            self.sequences_set.clear()
-
-    def fliter_batch(self) -> List["Sequence"]:
-        """
-        Remove completed sentences from a batch.
-
-        Returns:
-            List["Sequence"]: List of finished sequences.
-        """
-        finish_seqs = []
-        for seq in self.sequences_set:
-            if seq.check_finish():
-                finish_seqs.append(seq)
-        for finish_seq in finish_seqs:
-            self.sequences_set.discard(finish_seq)
-        return finish_seqs
-
-    def abort_seq(self, seq: "Sequence") -> "Sequence":
-        """
-        Remove sequence from the batch.
-        """
-        if not seq.check_finish():
-            seq.status = RequestStatus.ABORTED
-        self.sequences_set.discard(seq)
-        return seq
-
-    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
-        """
-        Add new sequence to batch
-
-        Args:
-            seqs (List["Sequence"]): The list of new sequences.
-        """
-        # covnert single sequence to list
-        if isinstance(seqs, Sequence):
-            seqs = [seqs]
-
-        for seq in seqs:
-            if seq in self.sequences_set:
-                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
-                continue
-            self.sequences_set.add(seq)
-
-    def del_seq(self, seq: Sequence) -> Sequence:
-        """
-        Delete sequence in batch
-        """
-        self.sequences_set.discard(seq)
-
-    @property
-    def is_empty(self) -> None:
-        """
-        Check whether sequences_set is empty.
-        """
-        return not self.sequences_set
-
-    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
-        """
-        Add an output token for each sentence in the batch.
-
-        Args:
-            tokens (List[int]): A batch of tokens
-        """
-
-        if isinstance(tokens, torch.Tensor):
-            tokens = tokens.tolist()
-
-        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
-
-        for seq, token in zip(self.sequences_set, tokens):
-            if not isinstance(token, list):
-                if not isinstance(token, int):
-                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
-                token = [token]
-            seq.output_token_id += token
-            seq.check_finish()
-
-    def get_batch_size(self) -> int:
-        """
-        Get batch_size of this batch
-        """
-        return len(self.sequences_set)
-
-    def get_batch_inputs(self) -> torch.LongTensor:
-        """
-        Get bacth inputs for forward inference computation.
-        """
-
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                if seq.output_len > 0:
-                    input_list.append(seq.input_token_id + seq.output_token_id)
-                else:
-                    input_list.append(seq.input_token_id)
-            else:
-                input_list.append([seq.output_token_id[-1]])
-
-        max_seq_len = max(len(sub_list) for sub_list in input_list)
-
-        # We assume that all the padding_id in seq are the same at present.
-        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
-
-    def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
-        """
-        Flattening the input tokens.
-        """
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                input_list.extend(seq.input_token_id)
-            else:
-                input_list.append(seq.output_token_id[-1])
-
-        return torch.tensor(input_list, dtype=torch.long, device=self.device)
-
-    def get_sequence_lengths(self):
-        """
-        Get the input_len of each sentence in this batch.
-        """
-        len_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            len_list.append(seq.sentence_len)
-
-        return torch.tensor(len_list, dtype=torch.int, device=self.device)
-
-    def get_attn_mask(self) -> torch.Tensor:
-        """
-        Generate and return attention mask.
-        """
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        past_values = []
-        # We assume that all the padding_id in seq are the same at present.
-        padding_id = self.sequences_set[0].pad_token_id
-
-        for seq in self.sequences_set:
-            past_values.append(seq.input_token_id + seq.output_token_id)
-
-        max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(
-            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
-        )
-
-        return attn_mask.ne(padding_id).long()
-
-    def __repr__(self) -> str:
-        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
-
-
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return [pad] * (max_len - len(x)) + x
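
Note: _pad_to_max survives the commit (it appears above as trailing context). A minimal usage sketch, with made-up inputs: it left-pads a token list to max_len, so the real tokens stay right-aligned.

    # Minimal sketch of the retained helper; the example values are made up.
    from typing import List

    def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
        assert len(x) <= max_len
        return [pad] * (max_len - len(x)) + x

    # Left-padding: pad ids go in front, real tokens stay right-aligned.
    assert _pad_to_max([5, 6, 7], max_len=5, pad=0) == [0, 0, 5, 6, 7]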

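Note: the removed get_batch_inputs and get_attn_mask also called a module-level helper _make_tensor_with_pad that this diff does not show. Judging only from those call sites, a plausible sketch (inferred, not the file's actual implementation) would pad each row and stack:

    # Inferred sketch, assuming _make_tensor_with_pad wraps _pad_to_max;
    # the real body in struct.py may differ, and the "cuda" default is a guess.
    from typing import List
    import torch

    def _make_tensor_with_pad(
        x: List[List[int]],
        max_len: int,
        pad: int,
        dtype: torch.dtype = torch.int,
        device: torch.device = "cuda",
    ) -> torch.Tensor:
        # Left-pad every row to max_len, then stack into a (batch, max_len) tensor.
        padded = [_pad_to_max(row, max_len, pad) for row in x]
        return torch.tensor(padded, dtype=dtype, device=device)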