
Commit 8d0f7a9

[inference] refactored config
1 parent 1f8c7e7 commit 8d0f7a9

2 files changed: +32 -22 lines changed

colossalai/inference/config.py

Lines changed: 32 additions & 21 deletions
@@ -35,49 +35,60 @@ class InferenceConfig:
     """The inference configuration.
 
     Args:
-        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
-        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
         max_batch_size (int): Maximum batch size, defaults to 8.
         max_output_len (int): Maximum output length, defaults to 256.
         max_input_len (int): Maximum input length, defaults to 256.
-        block_size (int): The number of blocks in a logical block, defaults to 16.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
-        tp_size (int): Tensor parallel size, defaults to 1.
-        pp_size (int): Pipeline parallel size, defaults to 1.
+        prompt_template (Optional[str]): The prompt template for generation, defaults to None.
+        do_sample (bool): Whether to use sampling for generation, defaults to False.
         beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
         prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill
             when the actual value exceeds this ratio.
         pad_input: Whether to pad all inputs to the max length.
-        quant_mode (Optional[str]): Quantization mode.
-        revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
-        prompt_template (Optional[str]): The prompt template for formatting the input text. Some built-in templates include 'llama' and 'vicuna'. Otherwise, the template should contain '{input_text}' for formatting the input text.
+        early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
+        top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
+        top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
+        min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
+        block_size (int): The number of blocks in a logical block, defaults to 16.
+        tp_size (int): Tensor parallel size, defaults to 1.
+        pp_size (int): Pipeline parallel size, defaults to 1.
+        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
+        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+
     """
 
-    micro_batch_size: int = 1
-    micro_batch_buffer_size: int = None
+    # NOTE: arrange configs according to their importance and frequency of usage
+
+    # runtime limit
     max_batch_size: int = 8
     max_output_len: int = 256
     max_input_len: int = 256
-    block_size: int = 16
+
+    # general configs
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
 
-    tp_size: int = 1
-    pp_size: int = 1
-    # TODO: beam search is not support for now
+    # generation configs
+    prompt_template: Optional[str] = None
     do_sample: bool = False
-    beam_width: int = 1
-    # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
-    prefill_ratio: Optional[float] = 1.2
+    beam_width: int = 1  # TODO: beam search is not support for now
+    prefill_ratio: Optional[
+        float
+    ] = 1.2  # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
     pad_input: bool = False
-    quant_mode: Optional[str] = None
-    revision: Optional[str] = None
     early_stopping: Optional[bool] = False
-
     top_k: Optional[int] = None
     top_p: Optional[float] = None
     min_p: Optional[float] = None
-    prompt_template: Optional[str] = None
+
+    # paged attention configs
+    block_size: int = 16
+
+    # model parallelism configs
+    tp_size: int = 1
+    pp_size: int = 1
+    micro_batch_size: int = 1
+    micro_batch_buffer_size: int = None
 
     def __post_init__(self):
         self._verify_config()
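
For reference, below is a minimal usage sketch of the refactored config. It assumes InferenceConfig is importable from colossalai.inference.config (inferred from the file path above) and behaves as a standard dataclass whose __post_init__ calls _verify_config(); the field names, grouping, and defaults are taken from the diff, while the concrete values are illustrative only.

    import torch

    from colossalai.inference.config import InferenceConfig

    # Illustrative values only; every field below appears in the refactored
    # dataclass shown in the diff above.
    inference_config = InferenceConfig(
        # runtime limits
        max_batch_size=8,
        max_input_len=256,
        max_output_len=256,
        # general configs
        dtype=torch.float16,
        # generation configs
        prompt_template=None,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        # paged attention configs
        block_size=16,
        # model parallelism configs
        tp_size=1,
        pp_size=1,
    )
    # __post_init__ runs self._verify_config() on construction, as shown above.

Since quant_mode and revision are removed by this commit, passing either keyword should now fail at construction time (a TypeError, assuming the usual dataclass-generated __init__).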

colossalai/inference/core/engine.py

Lines changed: 0 additions & 1 deletion
@@ -130,7 +130,6 @@ def _shardformer(
             enable_flash_attention=False,
             enable_jit_fused=False,
             enable_sequence_parallelism=False,
-            extra_kwargs={"quant": self.inference_config.quant_mode},
         )
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
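
A minimal sketch of the resulting call pattern, assuming the shardconfig built here is a colossalai.shardformer ShardConfig and showing only the keyword arguments visible in this hunk; build_sharded_model is a hypothetical wrapper for illustration, not part of the engine.

    from colossalai.shardformer import ShardConfig, ShardFormer

    def build_sharded_model(model, model_policy):
        # Mirrors _shardformer after this commit: the extra_kwargs entry that
        # carried inference_config.quant_mode is no longer passed.
        shardconfig = ShardConfig(
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model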
