
Commit 987ee76

merge
2 parents 68479ee + 95efad7 commit 987ee76

13 files changed: +524 -174 lines

colossalai/inference/config.py (+44, -34)
@@ -9,7 +9,7 @@
 from typing import Optional, Union
 
 import torch
-import torch.nn as nn
+import torch.distributed as dist
 
 GibiByte = 1024**3
 
@@ -21,44 +21,44 @@ class InferenceConfig:
     """The inference configuration.
 
     Args:
-        model: Path or nn.Module of this model.
-        tokenizer: Path of the tokenizer to use.
-        tokenizer_mode: "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.
-        trust_remote_code: Whether to trust remote code from huggingface.
-        max_batch_size: Maximum batch size.
-        max_output_len: Maximum output length.
-        max_input_len: Maximum input length.
-        block_size: The number of blocks in a logical block.
-        dtype: The data type for weights and activations.
-        tp_size: Tensor parallel size.
-        pp_size: Pipeline parallel size.
-        max_seq_len: Maximum length of input sentence.
-        quant_mode: Quantization mode.
-        revision: The specific version(a branch, name, a commit id, or a tag name) of model to use.
-        beam_width: The maximum beam width used to initialize KV Cache.
+        micro_batch_size (int): The micro batch size. Only useful when `pp_size` > 1.
+        micro_batch_buffer_size (int): The buffer size for micro batches. Normally, it should be the same as the number of pipeline stages.
+        max_batch_size (int): Maximum batch size.
+        max_output_len (int): Maximum output length.
+        max_input_len (int): Maximum input length.
+        block_size (int): The number of blocks in a logical block.
+        dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        tp_size (int): Tensor parallel size.
+        pp_size (int): Pipeline parallel size.
+        max_seq_len (int): Maximum length of input sentence.
+        beam_width (int): The maximum beam width used to initialize the KV cache.
            During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
-        prefill_ratio: A controling ratio for prefill and decoding in running list, we will do a step of prefill
+        prefill_ratio (Optional[float]): A controlling ratio for prefill and decoding in the running list; we will do a step of prefill
            when the actual value exceeds this ratio.
+        quant_mode (Optional[str]): Quantization mode.
+        revision (Optional[str]): The specific version (a branch, name, a commit id, or a tag name) of the model to use.
     """
 
-    model: Union[str, nn.Module]
-    tokenizer: str = None
-    tokenizer_mode: str = "auto"
-    trust_remote_code: bool = False
-    max_batch_size: int = None
+    micro_batch_size: int = 1
+    micro_batch_buffer_size: int = None
+    max_batch_size: int = 8
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
     dtype: Union[str, torch.dtype] = torch.float32
     tp_size: int = 1
     pp_size: int = 1
-    max_seq_len: Optional[int] = None
-    quant_mode: Optional[str] = None
-    revision: Optional[str] = None
-    beam_width: int = 1
+    max_seq_len: int = 512
     # TODO: beam search is not support for now
-    prefill_ratio: Optional[float] = 1.2
+    beam_width: int = 1
     # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
+    prefill_ratio: Optional[float] = 1.2
+    quant_mode: Optional[str] = None
+    revision: Optional[str] = None
+
+    def __post_init__(self):
+        self._init_batch_size()
+        self._verify_config()
 
     def _init_batch_size(self):
         """
@@ -81,10 +81,20 @@ def _init_batch_size(self):
                f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
            )
 
-    def __post_init__(self):
-        self._init_batch_size()
-        self._verify_args()
-
-    def _verify_args(self):
-        if self.tokenizer_mode not in ["auto", "slow"]:
-            raise ValueError("Tokenizer mode must be " "either 'auto' or 'slow'," f"but got {self.tokenizer_mode}")
+    def _verify_config(self) -> None:
+        """
+        Verify the input config
+        """
+        assert (
+            self.tp_size * self.pp_size == dist.get_world_size()
+        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+        assert self.dtype in [
+            "fp16",
+            "fp32",
+            "bf16",
+            torch.float32,
+            torch.float16,
+            torch.bfloat16,
+        ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
+        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
+        assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
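
For context, here is a minimal usage sketch of the reworked config (not part of the commit). It assumes a torch.distributed process group has already been initialized, e.g. via colossalai.launch_from_torch, since _verify_config() compares tp_size * pp_size against dist.get_world_size(); the exact launch call is an assumption, not something this diff shows.

# Illustrative only: build an InferenceConfig for a single-GPU run.
import colossalai
from colossalai.inference.config import InferenceConfig

colossalai.launch_from_torch(config={})  # assumed launcher; initializes torch.distributed

inference_config = InferenceConfig(
    max_batch_size=8,      # must be <= 64, enforced by _verify_config()
    max_input_len=256,
    max_output_len=256,
    dtype="fp16",          # one of "fp16", "fp32", "bf16" or the matching torch dtypes
    tp_size=1,
    pp_size=1,             # tp_size * pp_size must equal the world size
)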

colossalai/inference/core/engine.py (+199, -32)
@@ -1,65 +1,232 @@
-from logging import Logger
-from typing import Optional
+from itertools import count
+from typing import List, Optional, Union
 
-from transformers import AutoConfig
+import torch
+import torch.nn as nn
+from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
 
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.config import InferenceConfig
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.struct import Sequence
+from colossalai.logging import get_dist_logger
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .request_handler import RequestHandler
+
+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = [
+    "LlamaForCausalLM",
+]
 
 
 class InferenceEngine:
-    """
-    InferenceEngine is the core component for Inference.
 
-    It is responsible for launch the inference process, including:
-        - Initialize model and distributed training environment(if needed)
-        - Launch request_handler and corresponding kv cache manager
-        - Receive requests and generate texts.
-        - Log the generation process
+    """
+    InferenceEngine which manages the inference process.
 
     Args:
-        tokenizer: Path of the tokenizer to use.
-        inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
+        model (nn.Module): Path or nn.Module of this model.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
+        model_policy (Policy): The policy to shard the model. It will be determined by the model type if not provided.
     """
 
     def __init__(
         self,
-        tokenizer: str = None,
+        model: nn.Module,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: Optional["InferenceConfig"] = None,
        verbose: bool = False,
+        model_policy: Policy = None,
    ) -> None:
        assert inference_config, "Please provide inference_config."
-
-        self._init_model()
-        # cache_config may need to be modified later.
-        # self.request_handler = RequestHandler(cache_config)
        self.tokenizer = tokenizer
-        self.hf_model_config = AutoConfig.from_pretrained(
-            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
+        self.inference_config = inference_config
+        self.model_config = model.config
+
+        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
+            self.dtype = torch.float16
+            model.half()
+        else:
+            self.dtype = torch.bfloat16
+            model.to(torch.bfloat16)
+
+        if model_policy is None:
+            model_policy = model_policy_map[self.model_config.model_type]()
+
+        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
+
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            None,
+            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
        )
+
+        self.verbose = verbose
        if verbose:
-            self.logger = Logger()
+            self.logger = get_dist_logger(__name__)
+
+        self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.counter = count()
+
+    def _verify_config(self) -> None:
+        """
+        Verify the input config
+        """
+        if not isinstance(self.model, nn.Module):
+            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
+        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
+            self.tokenizer, PreTrainedTokenizer
+        ):
+            raise TypeError(
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
+            )
+        assert (
+            self.model.__class__.__name__ in _supported_models
+        ), f"Model {self.model.__class__.__name__} is not supported."
+
+    def _shardformer(
+        self,
+        model: nn.Module,
+        model_policy: Policy,
+        stage_manager: PipelineStageManager = None,
+        tp_group: ProcessGroupMesh = None,
+    ) -> nn.Module:
+        """
+        Initialize ShardConfig and replace the model with shardformer.
+
+        Args:
+            model (nn.Module): Path or nn.Module of this model.
+            model_policy (Policy): The policy to shard the model, which is determined by the model type.
+            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
+            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
+
+        Returns:
+            nn.Module: The model optimized by Shardformer.
+        """
+        shardconfig = ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            pipeline_stage_manager=stage_manager,
+            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
+            enable_fused_normalization=False,
+            enable_all_optimization=False,
+            enable_flash_attention=False,
+            enable_jit_fused=False,
+            enable_sequence_parallelism=False,
+            extra_kwargs={"quant": self.inference_config.quant_mode},
+        )
+        shardformer = ShardFormer(shard_config=shardconfig)
+        shard_model, _ = shardformer.optimize(model, model_policy)
+        return shard_model.cuda()
 
-    def _init_model(self):
+    def generate(
+        self,
+        generation_config: GenerationConfig = None,
+    ) -> List[str]:
        """
-        Initialize model and distributed training environment(if needed).
-        May need to provide two different initialization methods:
-            1. User-defined (from local path)
-            2. Load from a checkpoint (Hugging Face)
+        Executes the inference step.
+
+        Args:
+            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
+
+        Returns:
+            List[str]: Inference result returned by one generation.
        """
 
-    def _verify_config(self):
+        self.generation_config = generation_config
+
+        output_list = []
+
+        while self.request_handler.check_unfinished_seqs():
+            output_list += self.step()
+
+        return output_list
+
+    def add_request(
+        self,
+        requests_id: List[int] = None,
+        prompts: List[str] = None,
+        prompts_token_ids: List[int] = None,
+    ) -> None:
        """
-        Verify the configuration to avoid potential bugs.
+        Add requests.
+
+        Args:
+            requests_id (List[int], optional): The request IDs. Defaults to None.
+            prompts (List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (List[List[int]], optional): Token ids of the input prompts. Defaults to None.
        """
 
-    def generate(self):
-        pass
+        block_size = self.inference_config.block_size
 
-    def step(self):
+        if prompts_token_ids is None:
+            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
+            prompts_token_ids = []
+            for prompt in prompts:
+                prompts_token_ids.append(self.tokenizer.encode(prompt))
+
+        prompts_num = len(prompts_token_ids)
+
+        for i in range(prompts_num):
+            if requests_id:
+                request_id = requests_id[i]
+            else:
+                request_id = next(self.counter)
+            if prompts is None:
+                prompt = None
+            else:
+                prompt = prompts[i]
+            sequence = Sequence(
+                request_id,
+                prompt,
+                prompts_token_ids[i],
+                block_size,
+                None,
+                None,
+                self.tokenizer.eos_token_id,
+                self.inference_config.max_output_len,
+            )
+            self.request_handler.add_sequence(sequence)
+
+    def step(self) -> List[str]:
        """
        In each step, do the follows:
-            1. Run request_handler to update the kv cache and running input_ids
+            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Run model to generate the next token
-            3. Check whether there is finied request and decode
+            3. Update waiting list and running list in RequestHandler and get finished sequences.
+            4. Decode and return finished sequences.
+
+        Returns:
+            List[str]: Decoded finished sequences generated by one step.
        """
+
+        if self.verbose:
+            self.logger.info("Running generation step")
+
+        output_list = []
+        self.request_handler.schedule()
+
+        # Uncomment if the development of RequestHandler is completed.
+        # logits = self.model(batch)
+        # self.request_handler.search_tokens(logits, self.generation_config)
+
+        finished_sequences = self.request_handler.update()
+
+        # Decode completed sentences.
+        for seq in finished_sequences:
+            if seq.prompt:
+                output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
+                output_list.append(seq.prompt + output_str)
+            else:
+                output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
+                output_list.append(output_str)
+
+        return output_list
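
For context, a sketch of how the new InferenceEngine API is meant to be used (not part of the commit). The checkpoint path is a placeholder, and inference_config is the object built in the config sketch above; note that at this commit the model forward in step() is still commented out, so the flow mainly exercises request scheduling and decoding.

# Illustrative only: wire a Llama model and tokenizer into the new engine.
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.core.engine import InferenceEngine

model = LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")      # placeholder path
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-checkpoint")    # placeholder path

engine = InferenceEngine(model, tokenizer, inference_config=inference_config, verbose=True)

# add_request() tokenizes the prompts (when token ids are not given), wraps them into
# Sequence objects, and queues them in the RequestHandler; generate() then loops over
# step() until every sequence finishes and returns the decoded strings.
engine.add_request(prompts=["Introduce some landmarks in the United Kingdom."])
outputs = engine.generate()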

colossalai/inference/core/request_handler.py (+1)
@@ -60,6 +60,7 @@ class RequestHandler:
 
     Args:
        inference_config: Configuration for initialize and manage kv cache.
+        model_config: Configuration for model
    """
 
    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
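
For context (not shown in this diff), the handler is now constructed with both configs, mirroring the call in InferenceEngine.__init__:

# The engine passes its InferenceConfig together with the model's transformers
# PretrainedConfig (model.config) when creating the handler.
request_handler = RequestHandler(inference_config, model.config)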
