-from logging import Logger
-from typing import Optional
+from itertools import count
+from typing import List, Optional, Union

-from transformers import AutoConfig
+import torch
+import torch.nn as nn
+from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

+from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.config import InferenceConfig
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.struct import Sequence
+from colossalai.logging import get_dist_logger
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .request_handler import RequestHandler
+
+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = [
+    "LlamaForCausalLM",
+]


 class InferenceEngine:
-    """
-    InferenceEngine is the core component for Inference.

-    It is responsible for launching the inference process, including:
-        - Initialize model and distributed training environment(if needed)
-        - Launch request_handler and corresponding kv cache manager
-        - Receive requests and generate texts.
-        - Log the generation process
+    """
+    InferenceEngine which manages the inference process.

     Args:
-        tokenizer: Path of the tokenizer to use.
-        inference_config: We provide a unified config api that wraps all the configs. You can use it to replace the below configs.
+        model (nn.Module): The model for inference.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use for encoding prompts and decoding outputs.
+        inference_config (Optional[InferenceConfig], optional): Stores the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
+        model_policy (Policy): The policy used by Shardformer to shard the model. It will be determined by the model type if not provided.
     """

     def __init__(
         self,
-        tokenizer: str = None,
+        model: nn.Module,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: Optional["InferenceConfig"] = None,
         verbose: bool = False,
+        model_policy: Policy = None,
     ) -> None:
         assert inference_config, "Please provide inference_config."
-
-        self._init_model()
-        # cache_config may need to be modified later.
-        # self.request_handler = RequestHandler(cache_config)
         self.tokenizer = tokenizer
-        self.hf_model_config = AutoConfig.from_pretrained(
-            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
+        self.inference_config = inference_config
+        self.model_config = model.config
+
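+        # Map the configured dtype (string or torch dtype) onto a torch dtype and cast the model weights to match.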
+        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
+            self.dtype = torch.float16
+            model.half()
+        else:
+            self.dtype = torch.bfloat16
+            model.to(torch.bfloat16)
+
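+        # Fall back to the registered policy for this model type when no policy is given.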
+        if model_policy is None:
+            model_policy = model_policy_map[self.model_config.model_type]()
+
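+        # Build a process-group mesh over the pipeline-parallel and tensor-parallel axes.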
+        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
+
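+        # Shard the model with Shardformer; a TP process group is passed only when more than one device is used.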
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            None,
+            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
         )
+
+        self.verbose = verbose
         if verbose:
-            self.logger = Logger()
+            self.logger = get_dist_logger(__name__)
+
+        self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.counter = count()
+
+    def _verify_config(self) -> None:
+        """
+        Verify the input config.
+        """
+        if not isinstance(self.model, nn.Module):
+            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
+        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
+            self.tokenizer, PreTrainedTokenizer
+        ):
+            raise TypeError(
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
+            )
+        assert (
+            self.model.__class__.__name__ in _supported_models
+        ), f"Model {self.model.__class__.__name__} is not supported."
+
+    def _shardformer(
+        self,
+        model: nn.Module,
+        model_policy: Policy,
+        stage_manager: PipelineStageManager = None,
+        tp_group: ProcessGroupMesh = None,
+    ) -> nn.Module:
+        """
+        Initialize ShardConfig and replace the model with shardformer.
+
+        Args:
+            model (nn.Module): The model to be sharded.
+            model_policy (Policy): The policy used by Shardformer to shard the model, determined by the model type.
+            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
+            tp_group (ProcessGroupMesh, optional): The tensor-parallel process group. Defaults to None.
+
+        Returns:
+            nn.Module: The sharded model, placed on CUDA.
+        """
+        shardconfig = ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            pipeline_stage_manager=stage_manager,
+            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
+            enable_fused_normalization=False,
+            enable_all_optimization=False,
+            enable_flash_attention=False,
+            enable_jit_fused=False,
+            enable_sequence_parallelism=False,
+            extra_kwargs={"quant": self.inference_config.quant_mode},
+        )
+        shardformer = ShardFormer(shard_config=shardconfig)
+        shard_model, _ = shardformer.optimize(model, model_policy)
+        return shard_model.cuda()

-    def _init_model(self):
+    def generate(
+        self,
+        generation_config: GenerationConfig = None,
+    ) -> List[str]:
         """
-        Initialize model and distributed training environment(if needed).
-        May need to provide two different initialization methods:
-          1. User-defined (from local path)
-          2. Load from a checkpoint (Hugging Face)
+        Executes the inference process until all requests are finished.
+
+        Args:
+            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
+
+        Returns:
+            List[str]: Inference results returned by one generation.
         """

-    def _verify_config(self):
+        self.generation_config = generation_config
+
+        output_list = []
+
+        while self.request_handler.check_unfinished_seqs():
+            output_list += self.step()
+
+        return output_list
+
+    def add_request(
+        self,
+        requests_id: List[int] = None,
+        prompts: List[str] = None,
+        prompts_token_ids: List[List[int]] = None,
+    ) -> None:
         """
-        Verify the configuration to avoid potential bugs.
+        Add requests.
+
+        Args:
+            requests_id (List[int], optional): The request IDs. Defaults to None.
+            prompts (List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (List[List[int]], optional): Token ids of input prompts. Defaults to None.
         """

-    def generate(self):
-        pass
+        block_size = self.inference_config.block_size

-    def step(self):
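+        # Tokenize raw prompts when token ids are not provided by the caller.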
+        if prompts_token_ids is None:
+            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
+            prompts_token_ids = []
+            for prompt in prompts:
+                prompts_token_ids.append(self.tokenizer.encode(prompt))
+
+        prompts_num = len(prompts_token_ids)
+
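+        # Wrap each prompt into a Sequence and register it with the request handler.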
+        for i in range(prompts_num):
+            if requests_id:
+                request_id = requests_id[i]
+            else:
+                request_id = next(self.counter)
+            if prompts is None:
+                prompt = None
+            else:
+                prompt = prompts[i]
+            sequence = Sequence(
+                request_id,
+                prompt,
+                prompts_token_ids[i],
+                block_size,
+                None,
+                None,
+                self.tokenizer.eos_token_id,
+                self.inference_config.max_output_len,
+            )
+            self.request_handler.add_sequence(sequence)
+
+    def step(self) -> List[str]:
         """
         In each step, do the following:
-            1. Run request_handler to update the kv cache and running input_ids
+            1. Run RequestHandler.schedule() and get the batch used for inference.
             2. Run model to generate the next token
-            3. Check whether there is a finished request and decode
+            3. Update waiting list and running list in RequestHandler and get finished sequences.
+            4. Decode and return finished sequences.
+
+        Returns:
+            List[str]: Decoded finished sequences generated by one step.
         """
+
+        if self.verbose:
+            self.logger.info("Running generation step")
+
+        output_list = []
+        self.request_handler.schedule()
+
+        # Uncomment once the development of RequestHandler is completed.
+        # logits = self.model(batch)
+        # self.request_handler.search_tokens(logits, self.generation_config)
+
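+        # Update the waiting and running lists in the request handler and collect finished sequences.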
+        finished_sequences = self.request_handler.update()
+
+        # Decode completed sequences.
+        for seq in finished_sequences:
+            if seq.prompt:
+                output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
+                output_list.append(seq.prompt + output_str)
+            else:
+                output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
+                output_list.append(output_str)
+
+        return output_list
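
For context, below is a minimal usage sketch of the API introduced by this diff. It assumes the engine lives at colossalai.inference.core.engine and that InferenceConfig accepts dtype and max_output_len as constructor arguments (both attributes are referenced above, but the constructor signature is not shown); the checkpoint path is a placeholder. Note that in this snapshot step() does not yet run the model forward, so the loop below describes the intended flow rather than a working end-to-end run.

from transformers import AutoTokenizer, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # module path assumed

# Placeholder checkpoint; any LlamaForCausalLM checkpoint matches _supported_models.
model = LlamaForCausalLM.from_pretrained("path/to/llama")
tokenizer = AutoTokenizer.from_pretrained("path/to/llama")

# Constructor arguments assumed from the attributes used in the diff (dtype, max_output_len).
inference_config = InferenceConfig(dtype="fp16", max_output_len=64)

engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# Queue prompts; request ids are drawn from the engine's internal counter.
engine.add_request(prompts=["Hello, my name is", "The capital of France is"])

# Run steps until the request handler reports no unfinished sequences,
# then collect the decoded outputs.
outputs = engine.generate()
for text in outputs:
    print(text)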