20 changes: 20 additions & 0 deletions colossalai/inference/config.py
@@ -23,6 +23,12 @@
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]


_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
}


@dataclass
class InferenceConfig:
"""The inference configuration.
@@ -44,6 +50,7 @@ class InferenceConfig:
pad_input: Whether to pad all inputs to the max length.
quant_mode (Optional[str]): Quantization mode.
revision (Optional[str]): The specific version (a branch name, a commit id, or a tag name) of the model to use.
prompt_template (Optional[str]): The prompt template for formatting the input text. Built-in templates include 'llama' and 'vicuna'; a custom template must contain '{input_text}' as the placeholder for the input text.
"""

micro_batch_size: int = 1
@@ -62,6 +69,7 @@
pad_input: bool = False
quant_mode: Optional[str] = None
revision: Optional[str] = None
prompt_template: Optional[str] = None

def __post_init__(self):
self._verify_config()
@@ -85,3 +93,15 @@ def _verify_config(self) -> None:
assert (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"

# check prompt template
if self.prompt_template is None:
return

if self.prompt_template in _DEFAULT_PROMPT_TEMPLATES:
self.prompt_template = _DEFAULT_PROMPT_TEMPLATES[self.prompt_template]
else:
# make sure the template can be formatted with input_text
assert (
"{input_text}" in self.prompt_template
), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"
24 changes: 24 additions & 0 deletions colossalai/inference/core/engine.py
@@ -170,6 +170,26 @@ def generate(

return output_str

@property
def has_prompt_template(self) -> bool:
""" """
return self.inference_config.prompt_template is not None

def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
"""
Format the input prompts according to the prompt template provided in the InferenceConfig.
"""
assert (
self.has_prompt_template
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."

if isinstance(prompts, (list, tuple)):
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
elif isinstance(prompts, str):
return self.inference_config.prompt_template.format(input_text=prompts)
else:
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")

def add_request(
self,
requests_id: List[int] = None,
@@ -185,6 +205,10 @@
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
"""

# apply the prompt template to the input prompts
if self.has_prompt_template and prompts is not None:
prompts = self.format_prompt(prompts)

block_size = self.inference_config.block_size

if prompts_token_ids is None:
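The dispatch inside format_prompt reduces to one str.format call per prompt. A standalone sketch of the same logic, kept independent of the engine; apply_template is a hypothetical helper, and the template and prompts below are illustrative only:

from typing import List, Union

def apply_template(template: str, prompts: Union[List[str], str]) -> Union[List[str], str]:
    # Hypothetical helper mirroring InferenceEngine.format_prompt:
    # a list/tuple of prompts is formatted element-wise, a single string once.
    if isinstance(prompts, (list, tuple)):
        return [template.format(input_text=p) for p in prompts]
    elif isinstance(prompts, str):
        return template.format(input_text=prompts)
    raise TypeError(f"Expected list, tuple, or str, but got {type(prompts)}.")

print(apply_template("USER: {input_text}\n\nASSISTANT: ", ["Hello?", "Summarize this change."]))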
30 changes: 21 additions & 9 deletions tests/test_infer/test_inference_engine.py
@@ -6,9 +6,10 @@
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def setup_seed(seed):
@@ -18,7 +19,7 @@ def setup_seed(seed):
random.seed(seed)


def check_inference_engine(test_cai=False):
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = (
@@ -43,14 +44,17 @@ def check_inference_engine(test_cai=False):
top_p = 0.5
top_k = 50

if test_cai:
inference_config = InferenceConfig(max_output_len=output_len)
if use_engine:
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
@@ -68,14 +72,22 @@
return outputs


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
cai_outputs = check_inference_engine(True)
transformer_outputs = check_inference_engine(False)
@parameterize("prompt_template", [None, "llama"])
def check_output_consistency(prompt_template):
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)

for s1, s2 in zip(cai_outputs, transformer_outputs):
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"

# clear singleton flash decoding tensors
FDIntermTensors._instances = {}


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency()


@pytest.mark.dist
@rerun_if_address_is_in_use()