Commit 776237e

[Inference]Add BatchInferState, Sequence and InferConfig (#5149)
* add infer_struct and infer_config
* update codes
* change InferConfig
* Add hf_model_config to the engine
* rm _get_hf_model_config
* update codes
* made adjustments according to the feedback from the reviewer
* update codes
* add ci test for config and struct
1 parent 7fe0364 commit 776237e

File tree

5 files changed: +279 -34 lines changed


colossalai/inference/config.py

Lines changed: 0 additions & 7 deletions
This file was deleted.
colossalai/inference/core/config.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
from typing import Optional, Union
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class InferenceConfig:
    """The inference configuration.

    Args:
        model: Path or nn.Module of this model.
        tokenizer: Path of the tokenizer to use.
        tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Whether to trust remote code from huggingface.
        max_batch_size: Maximum batch size.
        max_output_len: Maximum output length.
        max_input_len: Maximum input length.
        block_size: The number of tokens in one logical block.
        gpu_utilization_rate: Maximum GPU memory usage ratio.
        dtype: The data type for weights and activations.
        tp_size: Tensor parallel size.
        pp_size: Pipeline parallel size.
        max_seq_len: Maximum length of the input sentence.
        quant_mode: Quantization mode.
        revision: The specific version (a branch name, a commit id, or a tag name) of the model to use.
    """

    model: Union[str, nn.Module]
    tokenizer: str = None
    tokenizer_mode: str = "auto"
    trust_remote_code: bool = False
    max_batch_size: int = 8
    max_output_len: int = 256
    max_input_len: int = 256
    block_size: int = 16
    gpu_utilization_rate: float = 0.7
    dtype: Union[str, torch.dtype] = torch.float32
    tp_size: int = 1
    pp_size: int = 1
    max_seq_len: Optional[int] = None
    quant_mode: Optional[str] = None
    revision: Optional[str] = None

    def __post_init__(self):
        self._verify_args()

    def _verify_args(self):
        if self.gpu_utilization_rate > 1.0:
            raise ValueError(
                f"GPU utilization rate should be no greater than 1.0, but is set to {self.gpu_utilization_rate}."
            )
        if self.tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(f"Tokenizer mode must be either 'auto' or 'slow', but got {self.tokenizer_mode}.")

colossalai/inference/core/engine.py

Lines changed: 19 additions & 27 deletions
@@ -1,12 +1,14 @@
 from logging import Logger
 from typing import Optional
 
-from .request_handler import RequestHandler
+from transformers import AutoConfig
 
+from .config import InferenceConfig
 
-class InferEngine:
+
+class InferenceEngine:
     """
-    InferEngine is the core component for Inference.
+    InferenceEngine is the core component for Inference.
 
     It is responsible for launch the inference process, including:
     - Initialize model and distributed training environment(if needed)
@@ -15,37 +17,27 @@ class InferEngine:
     - Log the generation process
 
     Args:
-        colossal_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
-        model_config : The configuration for the model.
-        parallel_config: The configuration for parallelize model.
-        cache_config : Configuration for initialize and manage kv cache.
-        tokenizer (Tokenizer): The tokenizer to be used for inference.
-        use_logger (bool): Determine whether or not to log the generation process.
+        tokenizer: Path of the tokenizer to use.
+        inference_config: We provide a unified config api that wraps all the configs. You can use it to replace the configs below.
+        verbose (bool): Determine whether or not to log the generation process.
     """
 
     def __init__(
         self,
-        model_config,
-        cache_config,
-        parallel_config,
-        tokenizer,
-        use_logger: bool = False,
-        colossal_config: Optional["ColossalInferConfig"] = None,
+        tokenizer: str = None,
+        inference_config: Optional["InferenceConfig"] = None,
+        verbose: bool = False,
     ) -> None:
-        assert colossal_config or (
-            model_config and cache_config and parallel_config
-        ), "Please provide colossal_config or model_config, cache_config, parallel_config"
-        if colossal_config:
-            model_config, cache_config, parallel_config = colossal_config
-
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.parallel_config = parallel_config
-        self._verify_config()
+        assert inference_config, "Please provide inference_config."
 
         self._init_model()
-        self.request_handler = RequestHandler(cache_config)
-        if use_logger:
+        # cache_config may need to be modified later.
+        # self.request_handler = RequestHandler(cache_config)
+        self.tokenizer = tokenizer
+        self.hf_model_config = AutoConfig.from_pretrained(
+            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
+        )
+        if verbose:
             self.logger = Logger()
 
     def _init_model(self):
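
A sketch of how the reworked constructor is meant to be called; this is illustrative only, since _init_model is not shown in this diff and the "/llama" checkpoint path is a placeholder:

from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

inference_config = InferenceConfig(model="/llama", tokenizer="/llama")  # placeholder paths
engine = InferenceEngine(
    tokenizer="/llama",
    inference_config=inference_config,
    verbose=True,  # attaches a Logger for the generation process
)
print(engine.hf_model_config)  # HF config resolved via AutoConfig.from_pretrained
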
colossalai/inference/core/inference_struct.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
import enum
from dataclasses import dataclass
from typing import Dict, List, Set


class RequsetStatus(enum.Enum):
    """The status of Sentences"""

    WAITING = enum.auto()
    RUNNING = enum.auto()
    ABORTED = enum.auto()
    OVERLENGTH = enum.auto()
    COMPLETED = enum.auto()
    LENGTH_CAPPED = enum.auto()

    @staticmethod
    def is_finished(status: "RequsetStatus") -> bool:
        return status in [
            RequsetStatus.OVERLENGTH,
            RequsetStatus.COMPLETED,
            RequsetStatus.LENGTH_CAPPED,
        ]

    @staticmethod
    def is_running(status: "RequsetStatus") -> bool:
        return status == RequsetStatus.RUNNING

    @staticmethod
    def is_waiting(status: "RequsetStatus") -> bool:
        return status == RequsetStatus.WAITING


class Sequence:
    """Store information of an input sequence.

    Args:
        request_id: The ID of the input sequence.
        prompt: The prompt of the input sequence.
        token_id: The token IDs of the input sequence.
        block_size: The block size of the input sequence.
        sample_params: The sample_params of the input sequence.
        block_table_index: The index of the input sequence in block_table.
    """

    def __init__(
        self,
        request_id: int,
        prompt: str,
        token_id: List[int],
        block_size: int,
        sample_params,  # SampleParams needs to be imported later.
        block_table_index: int,
    ):
        self.request_id = request_id
        self.prompt = prompt
        self.input_token_id = token_id
        self.block_size = block_size
        self.sample_params = sample_params
        self.output_token_id = []
        self.status = RequsetStatus.WAITING
        self.block_table_index = block_table_index

    def get_sentence_len(self) -> int:
        """
        Get the length of the current sentence (input plus output).
        """
        return len(self.input_token_id) + len(self.output_token_id)

    def get_input_len(self) -> int:
        """
        Get the length of the input sentence.
        """
        return len(self.input_token_id)

    def get_output_len(self) -> int:
        """
        Get the output length of the current sentence.
        """
        return len(self.output_token_id)

    def check_finish(self) -> bool:
        """
        Check whether inference is over.
        """
        return RequsetStatus.is_finished(self.status)

    def __repr__(self) -> str:
        return (
            f"Request ID(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"logical block number={len(self._logical_blocks)}"
        )


@dataclass
class BatchHandler:
    """
    Information to be passed and used for a batch of sequences.
    """

    sequences_set: Set[Sequence]
    block_table: Dict[int, int]

    @classmethod
    def init_batch(cls, seqs: List[Sequence]) -> "BatchHandler":
        """
        Initialize an inference batch from a list of input sequences.

        Args:
            seqs (List[Sequence]): List of input sequences.
        """
        sequences_set = set()
        block_table = {}
        for seq in seqs:
            if seq in sequences_set:
                print("The sequence is already in sequences_set.")
                assert (
                    seq.request_id in block_table
                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
                continue
            assert (
                seq.request_id not in block_table
            ), "The sequence has not been added to sequences_set, but it is already in block_table."

            sequences_set.add(seq)
            block_table[seq.request_id] = seq.block_table_index

        return cls(sequences_set=sequences_set, block_table=block_table)

    def clear_batch(self) -> None:
        """
        Clear the sequence set and block table.
        """
        for seq in self.sequences_set:
            if not seq.check_finish():
                seq.status = RequsetStatus.ABORTED
        self.sequences_set.clear()
        self.block_table.clear()

    def fliter_batch(self) -> None:
        """
        Remove completed sentences from the batch.
        """
        # Iterate over a copy so the set can be modified safely during the loop.
        for seq in list(self.sequences_set):
            if seq.check_finish():
                self.sequences_set.remove(seq)
                del self.block_table[seq.request_id]

    def add_seqs(self, seqs: List[Sequence]) -> None:
        """
        Add new sequences to the batch.

        Args:
            seqs (List[Sequence]): The list of new sequences.
        """
        for seq in seqs:
            if seq in self.sequences_set:
                print("The sequence is already in sequences_set.")
                assert (
                    seq.request_id in self.block_table
                ), "The sequence has been added to sequences_set, but it has not been added to block_table."
                continue
            assert (
                seq.request_id not in self.block_table
            ), "The sequence has not been added to sequences_set, but it is already in block_table."
            self.sequences_set.add(seq)
            self.block_table[seq.request_id] = seq.block_table_index
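
To make the batch lifecycle concrete, a small sketch using the classes above (sample_params is left as None because SampleParams is not wired in yet): a finished sequence is dropped by fliter_batch, and clear_batch aborts whatever remains.

from colossalai.inference.core.inference_struct import BatchHandler, RequsetStatus, Sequence

seq_a = Sequence(request_id=1, prompt="hello", token_id=[1, 2], block_size=16,
                 sample_params=None, block_table_index=0)
seq_b = Sequence(request_id=2, prompt="world", token_id=[3, 4], block_size=16,
                 sample_params=None, block_table_index=1)

batch = BatchHandler.init_batch([seq_a, seq_b])
seq_a.status = RequsetStatus.COMPLETED  # pretend generation for seq_a is done
batch.fliter_batch()                    # removes seq_a from the set and the block table
assert seq_a not in batch.sequences_set and 2 in batch.block_table

batch.clear_batch()                     # aborts unfinished sequences and empties the batch
assert seq_b.status == RequsetStatus.ABORTED
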
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
from colossalai.inference.core.config import InferenceConfig
from colossalai.inference.core.inference_struct import BatchHandler, Sequence


def test_config_and_struct():
    InferenceConfig("/llama")
    sequence = Sequence(
        request_id=1,
        prompt="abc",
        token_id=[1, 2, 3],
        block_size=16,
        sample_params=None,
        block_table_index=1,
    )

    sequence2 = Sequence(
        request_id=2,
        prompt="bcd",
        token_id=[4, 5, 6],
        block_size=16,
        sample_params=None,
        block_table_index=2,
    )

    assert sequence.get_sentence_len() == 3
    assert sequence.get_input_len() == 3
    assert sequence.get_output_len() == 0
    assert sequence.check_finish() == False

    batch = BatchHandler.init_batch([sequence])
    batch.fliter_batch()
    batch.add_seqs([sequence2])
    batch.clear_batch()


if __name__ == "__main__":
    test_config_and_struct()
