Commit 58740b5

[inference] added inference template (#5375)
1 parent 8106ede commit 58740b5

3 files changed: +65 −9 lines changed

colossalai/inference/config.py

Lines changed: 20 additions & 0 deletions
@@ -23,6 +23,12 @@
 _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
 
 
+_DEFAULT_PROMPT_TEMPLATES = {
+    "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
+    "vicuna": "USER: {input_text}\n\nASSISTANT: ",
+}
+
+
 @dataclass
 class InferenceConfig:
     """The inference configuration.
@@ -44,6 +50,7 @@ class InferenceConfig:
         pad_input: Whether to pad all inputs to the max length.
         quant_mode (Optional[str]): Quantization mode.
         revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
+        prompt_template (Optional[str]): The prompt template for formatting the input text. Some built-in templates include 'llama' and 'vicuna'. Otherwise, the template should contain '{input_text}' for formatting the input text.
     """
 
     micro_batch_size: int = 1
@@ -62,6 +69,7 @@ class InferenceConfig:
     pad_input: bool = False
     quant_mode: Optional[str] = None
     revision: Optional[str] = None
+    prompt_template: Optional[str] = None
 
     def __post_init__(self):
         self._verify_config()
@@ -85,3 +93,15 @@ def _verify_config(self) -> None:
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+
+        # check prompt template
+        if self.prompt_template is None:
+            return
+
+        if self.prompt_template in _DEFAULT_PROMPT_TEMPLATES:
+            self.prompt_template = _DEFAULT_PROMPT_TEMPLATES[self.prompt_template]
+        else:
+            # make sure the template can be formatted with input_text
+            assert (
+                "{input_text}" in self.prompt_template
+            ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"

colossalai/inference/core/engine.py

Lines changed: 24 additions & 0 deletions
@@ -170,6 +170,26 @@ def generate(
 
         return output_str
 
+    @property
+    def has_prompt_template(self) -> bool:
+        """Whether a prompt template is provided in the inference config."""
+        return self.inference_config.prompt_template is not None
+
+    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
+        """
+        This method will format the input prompt according to the prompt template given to the InferenceConfig.
+        """
+        assert (
+            self.has_prompt_template
+        ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
+
+        if isinstance(prompts, (list, tuple)):
+            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
+        elif isinstance(prompts, str):
+            return self.inference_config.prompt_template.format(input_text=prompts)
+        else:
+            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
+
     def add_request(
         self,
         requests_id: List[int] = None,
@@ -185,6 +205,10 @@ def add_request(
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
         """
 
+        # apply the prompt template to the input prompts
+        if self.has_prompt_template and prompts is not None:
+            prompts = self.format_prompt(prompts)
+
         block_size = self.inference_config.block_size
 
         if prompts_token_ids is None:
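
For orientation, the snippet below sketches how the new option is meant to be used end to end, following the call pattern of the updated test further down. The single-process launch, the port, the tiny randomly initialized Llama model, and the generation settings are assumptions for illustration; a CUDA device is assumed to be available.

import colossalai
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Assumption: a single-process launch on a free port; the test instead spawns workers via run_dist.
colossalai.launch(config={}, rank=0, world_size=1, port=29500, host="localhost")

# Assumption: a tiny randomly initialized Llama model, only to keep the sketch self-contained.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).cuda()

# prompt_template can be a built-in name ('llama', 'vicuna') or any string containing '{input_text}'.
inference_config = InferenceConfig(max_output_len=64, prompt_template="llama")
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# add_request() applies the template via format_prompt() before tokenization.
engine.add_request(prompts=["Introduce some landmarks in the United Kingdom."])
outputs = engine.generate(generation_config=GenerationConfig(do_sample=False))
print(outputs)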

tests/test_infer/test_inference_engine.py

Lines changed: 21 additions & 9 deletions
@@ -6,9 +6,10 @@
 from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
 
 import colossalai
-from colossalai.inference.config import InferenceConfig
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 def setup_seed(seed):
@@ -18,7 +19,7 @@ def setup_seed(seed):
     random.seed(seed)
 
 
-def check_inference_engine(test_cai=False):
+def check_inference_engine(use_engine=False, prompt_template=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = (
@@ -43,14 +44,17 @@ def check_inference_engine(test_cai=False):
     top_p = 0.5
     top_k = 50
 
-    if test_cai:
-        inference_config = InferenceConfig(max_output_len=output_len)
+    if use_engine:
+        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
         generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
+        if prompt_template:
+            # apply prompt template
+            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.pad_token_id = tokenizer.eos_token_id
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
@@ -68,14 +72,22 @@ def check_inference_engine(test_cai=False):
     return outputs
 
 
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
-    cai_outputs = check_inference_engine(True)
-    transformer_outputs = check_inference_engine(False)
+@parameterize("prompt_template", [None, "llama"])
+def check_output_consistency(prompt_template):
+    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
+    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
 
     for s1, s2 in zip(cai_outputs, transformer_outputs):
         assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
 
+    # clear singleton flash decoding tensors
+    FDIntermTensors._instances = {}
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_output_consistency()
+
 
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
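
The @parameterize decorator from colossalai.testing calls the wrapped check once per listed value, which is why run_dist can invoke check_output_consistency() with no arguments. A simplified stand-in (not the real implementation) illustrates the pattern:

# Simplified stand-in for colossalai.testing.parameterize; the real decorator may differ in detail.
import functools


def parameterize(arg_name, values):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator


@parameterize("prompt_template", [None, "llama"])
def check_output_consistency(prompt_template):
    print(f"checking consistency with prompt_template={prompt_template!r}")


check_output_consistency()  # runs twice: once without a template, once with the built-in 'llama' template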
