
Commit ccd5036

rm deprecated cls & deps
1 parent d1dadf0 commit ccd5036

2 files changed (+2, -290 lines)


colossalai/inference/struct.py

Lines changed: 1 addition & 241 deletions
@@ -1,11 +1,7 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Tuple, Union
+from typing import Any, List

-import torch
-from ordered_set import OrderedSet
-
-from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger

 logger = get_dist_logger(__name__)
@@ -170,242 +166,6 @@ def __repr__(self) -> str:
         )


-@dataclass
-class BatchInfo:
-    """
-    Information to be passed and used for a batch of sequences.
-    """
-
-    max_batch_size: int
-    kv_max_split_num: int
-    num_heads: int
-    head_dim: int
-    sequences_set: OrderedSet[Sequence] = None
-    is_prompts: bool = True
-    device: torch.device = None
-    dtype: torch.dtype = None
-    fd_inter_tensor: FDIntermTensors = None
-
-    def __post_init__(self):
-        if self.device is None:
-            self.device = torch.cuda.current_device()
-        if self.sequences_set is None:
-            self.sequences_set = OrderedSet()
-        if self.fd_inter_tensor is None:
-            self.fd_inter_tensor = FDIntermTensors()
-
-    def init_fd_tensors(self):
-        if not self.fd_inter_tensor.is_initialized:
-            self.fd_inter_tensor.initialize(
-                max_batch_size=self.max_batch_size,
-                num_attn_heads=self.num_heads,
-                kv_max_split_num=self.kv_max_split_num,
-                head_dim=self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            )
-
-    def get_block_table_tensor(self) -> None:
-        tesnor_list = []
-        block_table = None
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            block_table = seq.block_table
-            assert (
-                block_table is not None
-            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
-            tesnor_list.append(seq.block_table)
-
-        block_table = torch.stack(tesnor_list)
-        return block_table
-
-    def clear_batch(self) -> None:
-        """
-        Clear sequence set and block table if we need to abort this batch.
-        Prefill: clear sequence set and move them to running batch(external)
-        Decoding: mark unfinished sequences as aborted.
-        """
-        if self.is_prompts:
-            self.sequences_set.clear()
-        else:
-            for seq in self.sequences_set:
-                seq.mark_aborted()
-                if seq.check_finish():
-                    seq.mark_finished()
-
-            self.sequences_set.clear()
-
-    def fliter_batch(self) -> List["Sequence"]:
-        """
-        Remove completed sentences from a batch.
-
-        Returns:
-            List["Sequence"]: List of finished sequences.
-        """
-        finish_seqs = []
-        for seq in self.sequences_set:
-            if seq.check_finish():
-                finish_seqs.append(seq)
-        for finish_seq in finish_seqs:
-            self.sequences_set.discard(finish_seq)
-        return finish_seqs
-
-    def abort_seq(self, seq: "Sequence") -> "Sequence":
-        """
-        Remove sequence from the batch.
-        """
-        if not seq.check_finish():
-            seq.status = RequestStatus.ABORTED
-        self.sequences_set.discard(seq)
-        return seq
-
-    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
-        """
-        Add new sequence to batch
-
-        Args:
-            seqs (List["Sequence"]): The list of new sequences.
-        """
-        # covnert single sequence to list
-        if isinstance(seqs, Sequence):
-            seqs = [seqs]
-
-        for seq in seqs:
-            if seq in self.sequences_set:
-                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
-                continue
-            self.sequences_set.add(seq)
-
-    def del_seq(self, seq: Sequence) -> Sequence:
-        """
-        Delete sequence in batch
-        """
-        self.sequences_set.discard(seq)
-
-    @property
-    def is_empty(self) -> None:
-        """
-        Check whether sequences_set is empty.
-        """
-        return not self.sequences_set
-
-    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
-        """
-        Add an output token for each sentence in the batch.
-
-        Args:
-            tokens (List[int]): A batch of tokens
-        """
-
-        if isinstance(tokens, torch.Tensor):
-            tokens = tokens.tolist()
-
-        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
-
-        for seq, token in zip(self.sequences_set, tokens):
-            if not isinstance(token, list):
-                if not isinstance(token, int):
-                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
-                token = [token]
-            seq.output_token_id += token
-            seq.check_finish()
-
-    def get_batch_size(self) -> int:
-        """
-        Get batch_size of this batch
-        """
-        return len(self.sequences_set)
-
-    def get_batch_inputs(self) -> torch.LongTensor:
-        """
-        Get bacth inputs for forward inference computation.
-        """
-
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                if seq.output_len > 0:
-                    input_list.append(seq.input_token_id + seq.output_token_id)
-                else:
-                    input_list.append(seq.input_token_id)
-            else:
-                input_list.append([seq.output_token_id[-1]])
-
-        max_seq_len = max(len(sub_list) for sub_list in input_list)
-
-        # We assume that all the padding_id in seq are the same at present.
-        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
-
-    def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
-        """
-        Flattening the input tokens.
-        """
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                input_list.extend(seq.input_token_id)
-            else:
-                input_list.append(seq.output_token_id[-1])
-
-        return torch.tensor(input_list, dtype=torch.long, device=self.device)
-
-    def get_sequence_lengths(self):
-        """
-        Get the input_len of each sentence in this batch.
-        """
-        len_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            len_list.append(seq.sentence_len)
-
-        return torch.tensor(len_list, dtype=torch.int, device=self.device)
-
-    def get_attn_mask(self) -> torch.Tensor:
-        """
-        Generate and return attention mask.
-        """
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        past_values = []
-        # We assume that all the padding_id in seq are the same at present.
-        padding_id = self.sequences_set[0].pad_token_id
-
-        for seq in self.sequences_set:
-            past_values.append(seq.input_token_id + seq.output_token_id)
-
-        max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(
-            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
-        )
-
-        return attn_mask.ne(padding_id).long()
-
-    def __repr__(self) -> str:
-        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
-
-
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return [pad] * (max_len - len(x)) + x
-
-
-def _make_tensor_with_pad(
-    x: Union[List[List[int]], List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
-):
-    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")

tests/test_infer/test_config_and_struct.py

Lines changed: 1 addition & 49 deletions
@@ -2,7 +2,7 @@

 import colossalai
 from colossalai.inference.config import InferenceConfig
-from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
+from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.testing import rerun_if_address_is_in_use, spawn


@@ -20,27 +20,6 @@ def check_config_and_inference():
         max_output_len=256,
     )

-    sequence2 = Sequence(
-        request_id=2,
-        prompt="bcd",
-        input_token_id=[4, 5, 6],
-        block_size=16,
-        sample_params=None,
-        eos_token_id=2,
-        pad_token_id=2,
-        max_output_len=256,
-    )
-
-    sequence3 = Sequence(
-        request_id=3,
-        prompt="efg",
-        input_token_id=[7, 8, 9],
-        block_size=16,
-        sample_params=None,
-        eos_token_id=2,
-        pad_token_id=2,
-        max_output_len=256,
-    )
     sequence.mark_running()
     assert sequence.status == RequestStatus.RUNNING
     sequence.recycle()
@@ -51,33 +30,6 @@ def check_config_and_inference():
     assert sequence.output_len == 0
     assert sequence.check_finish() == False

-    batch = BatchInfo(
-        max_batch_size=8,
-        kv_max_split_num=16,
-        num_heads=2,
-        head_dim=128,
-    )
-    batch.add_seqs([sequence])
-    batch.add_seqs([sequence2, sequence3])
-
-    # add duplicated sequence to test that it will not be counted twice
-    batch.add_seqs([sequence])
-
-    assert batch.is_empty == False
-    assert batch.get_batch_size() == 3
-    batch.update_batch_tokens([1, 2, 3])
-    seq = batch.abort_seq(sequence)
-    seq2 = batch.fliter_batch()[0]
-
-    assert batch.get_batch_size() == 1
-    assert seq.output_len == 1
-    assert seq.output_token_id == [1]
-    assert seq2.output_len == 1
-    assert seq2.output_token_id == [2]
-
-    batch.clear_batch()
-    assert batch.is_empty == True
-

 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
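With BatchInfo gone, check_config_and_inference only exercises the Sequence lifecycle. A condensed sketch of the remaining flow, assuming a Sequence built like the deleted sequence2/sequence3 examples; the request_id, prompt, and input_token_id values here are illustrative, not taken from the diff:

    from colossalai.inference.struct import RequestStatus, Sequence

    sequence = Sequence(
        request_id=1,              # illustrative id, not shown in this diff
        prompt="abc",              # illustrative prompt
        input_token_id=[1, 2, 3],  # illustrative token ids
        block_size=16,
        sample_params=None,
        eos_token_id=2,
        pad_token_id=2,
        max_output_len=256,
    )

    sequence.mark_running()
    assert sequence.status == RequestStatus.RUNNING
    sequence.recycle()  # reset the sequence for reuse
    # (intervening status checks are elided in this diff)
    assert sequence.output_len == 0
    assert sequence.check_finish() == False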
