5 changes: 3 additions & 2 deletions colossalai/inference/config.py
@@ -202,11 +202,12 @@ class InferenceConfig(RPC_PARAM):
] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio
pad_input: bool = False
early_stopping: Optional[bool] = False
top_k: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = 50
top_p: Optional[float] = 1.0
temperature: Optional[float] = 1.0
no_repeat_ngram_size: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
forced_eos_token_id: int = None

# speculative decoding configs
max_n_spec_tokens: int = 5
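For reference, a minimal usage sketch of the new defaults (not part of this diff, and assuming the remaining InferenceConfig fields keep workable defaults): top_k and top_p now fall back to 50 and 1.0 instead of None, and forced_eos_token_id can be set to force an end-of-sequence token once a sequence reaches its maximum length.

    from colossalai.inference.config import InferenceConfig

    # Sketch only: the argument values are illustrative, not taken from this PR.
    cfg = InferenceConfig(
        temperature=0.7,
        repetition_penalty=1.2,
        forced_eos_token_id=2,  # illustrative token id; the real value depends on the tokenizer
    )
    assert cfg.top_k == 50 and cfg.top_p == 1.0  # new defaults introduced by this change
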
22 changes: 13 additions & 9 deletions colossalai/inference/core/engine.py
@@ -76,6 +76,7 @@ def __init__(
self.init_model(model_or_path, model_policy)

self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()

self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -524,12 +525,13 @@ def generate(
Returns:
List[str]: Inference result returned by one generation.
"""

gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

with torch.inference_mode():
if isinstance(prompts, str) and isinstance(request_ids, int):
prompts = [prompts]
request_ids = [request_ids]
if prompts is not None or prompts_token_ids is not None:
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
self.add_request(
request_ids=request_ids,
prompts=prompts,
@@ -543,6 +545,7 @@
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict

if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
@@ -688,11 +691,12 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
)

batch_token_ids = None
config_dict = self.generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
batch_token_ids = batch.batch_token_ids
if (
self.generation_config.repetition_penalty != 1.0
or self.generation_config.no_repeat_ngram_size > 0
or self.generation_config.forced_eos_token_id is not None
):
batch_token_ids = batch.batch_token_ids

# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
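The engine now caches generation_config_dict at construction time and refreshes it whenever generate() receives a per-call GenerationConfig, and prepare_input checks the config attributes directly instead of scanning a dict. A hedged sketch of the caller side (the engine instance and the prompt are illustrative; GenerationConfig is the Hugging Face class already used above):

    from transformers import GenerationConfig

    # Passing a config here replaces both engine.generation_config and the cached dict.
    gen_cfg = GenerationConfig(do_sample=True, top_k=50, top_p=0.9, max_length=128)
    # outputs = engine.generate(prompts=["An example prompt"], generation_config=gen_cfg)
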
7 changes: 6 additions & 1 deletion colossalai/inference/core/rpc_engine.py
@@ -257,7 +257,12 @@ async def step_(self, input_token_ids, input_meta_data: InputMetaData):
assert len(self.workers) == self.tp_size, "init workers first"

init_tasks = [
self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param())
self.async_parallel_wrapper(
worker.execute_model_forward,
input_token_ids,
input_meta_data.to_rpc_param(),
self.generation_config_dict,
)
for worker in self.workers
]
ret = await asyncio.gather(*init_tasks)
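The call above fans one identical forward request, now including the generation-config dict, out to every tensor-parallel worker and awaits all replies together. A self-contained sketch of that pattern with hypothetical names (the real code wraps the RPC call in async_parallel_wrapper rather than a thread):

    import asyncio

    async def fan_out(workers, token_ids, meta_param, gen_cfg_dict):
        # One task per worker; every worker receives the same inputs.
        tasks = [
            asyncio.to_thread(w.execute_model_forward, token_ids, meta_param, gen_cfg_dict)
            for w in workers
        ]
        return await asyncio.gather(*tasks)
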
6 changes: 4 additions & 2 deletions colossalai/inference/executor/rpc_worker.py
@@ -97,7 +97,9 @@ def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]
)
logger.info("physical cache init over")

def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict):
def exposed_execute_model_forward(
self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict
):
# prepare the data for model forward
input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param)
input_meta_data.fd_inter_tensor = self.fd_inter_tensor
@@ -120,7 +122,7 @@ def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = search_tokens(
self.inference_config.to_generation_config(self.model_config),
generation_config_param,
logits,
input_meta_data.is_prompts,
input_meta_data.batch_token_ids,
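With this change the worker no longer rebuilds a GenerationConfig from its InferenceConfig on every step; it forwards the dict received over RPC straight to search_tokens. A hedged sketch of the call it now makes (the dict values are illustrative):

    # An illustrative generation-config dict, as produced by GenerationConfig.to_dict().
    generation_config_param = {"do_sample": True, "top_k": 50, "top_p": 0.9, "temperature": 0.8}
    # next_tokens = search_tokens(
    #     generation_config_param,
    #     logits,
    #     input_meta_data.is_prompts,
    #     input_meta_data.batch_token_ids,
    # )
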
87 changes: 64 additions & 23 deletions colossalai/inference/logit_processors.py
@@ -1,27 +1,28 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
from typing import List
import logging
from typing import List, Union

import torch
import torch.nn.functional as F

_LOGIT_PROCESSOR_MAP = {}
_LOGITS_PROCESSOR_MAP = {}


def register_logit_processor(process_type):
def register_logits_processor(process_type):
"""
register flops computation function for operation.
"""

def register(func):
global _LOGIT_PROCESSOR_MAP
_LOGIT_PROCESSOR_MAP[process_type] = func
global _LOGITS_PROCESSOR_MAP
_LOGITS_PROCESSOR_MAP[process_type] = func
return func

return register


@register_logit_processor("no_repeat_ngram_size")
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
@register_logits_processor("no_repeat_ngram_size")
def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]):
"""
enforces no repetition of n-grams to avoid repetitions of word sequences.
"""
@@ -52,16 +53,16 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
return logits


@register_logit_processor("repetition_penalty")
def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
@register_logits_processor("repetition_penalty")
def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]):
"""
apply the penalty to the tokens present in the prompt.
"""

if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")

logit_list = []
logits_list = []

# TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
if penalty != 1.0:
@@ -71,15 +72,15 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li

curretn_socre = torch.gather(current_logit, 0, current_token)
curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty)
logit_list.append(current_logit.scatter(0, current_token, curretn_socre))
logits_list.append(current_logit.scatter(0, current_token, curretn_socre))

logits = torch.stack(logit_list)
logits = torch.stack(logits_list)

return logits


@register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float):
@register_logits_processor("temperature")
def apply_temperature(logits, temperature: float):
"""
apply temperature scaling.
"""
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
return logits if temperature == 1.0 else logits / temperature


@register_logit_processor("top_k")
def top_k_logit_processor(logits, top_k: int):
@register_logits_processor("top_k")
def apply_top_k(logits, top_k: int):
"""
top_k logit processor
"""
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
return logits


@register_logit_processor("top_p")
def top_p_logit_processor(logits, top_p: float):
@register_logits_processor("top_p")
def apply_top_p(logits, top_p: float):
"""
top_p logit processor
"""
@@ -129,7 +130,46 @@ def top_p_logit_processor(logits, top_p: float):
return logits


def logit_processor(processor: str, logits, *args, **kwargs):
@register_logits_processor("forced_eos_token_id")
def apply_forced_eos_token_id(
logits: torch.Tensor,
sequence_lengths: Union[torch.Tensor, List[int]],
max_lengths: Union[torch.Tensor, List[int]],
eos_token_id: Union[int, List[int]],
):
"""
Enforces the specified token as the last generated token when the maximum output length
is reached. Notice that the maximum output lengths for different sequences, even if they're
in the same batch, can be different.

Args:
logits(torch.Tensor): logits
sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
max_lengths(torch.Tensor): the maximum length for each sequence
eos_token_id(Union[int, List[int]]): forced eos token id
"""
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if isinstance(sequence_lengths, torch.Tensor):
sequence_lengths = sequence_lengths.tolist()
if isinstance(max_lengths, torch.Tensor):
max_lengths = max_lengths.tolist()

select_indexes = []
num_sequences = logits.shape[0]
sequence_lengths = sequence_lengths[:num_sequences]
max_lengths = max_lengths[:num_sequences]
for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)):
if sequence_length == max_out_length - 1:
select_indexes.append(i)
if select_indexes:
logits[select_indexes, :] = -float("inf")
logits[select_indexes, eos_token_id] = 0

return logits


def get_logits_processor(processor: str, logits, *args, **kwargs):
"""
do logit process for given logits.

@@ -140,9 +180,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
Returns:
logits after process
"""
if processor not in _LOGIT_PROCESSOR_MAP:
return logits
if processor not in _LOGITS_PROCESSOR_MAP:
logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
else:
func = _LOGIT_PROCESSOR_MAP[processor]
func = _LOGITS_PROCESSOR_MAP[processor]
logits = func(logits, *args, **kwargs)
return logits

return logits
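
An illustrative check (not part of this PR, assuming apply_forced_eos_token_id is imported from this module) of the forced-EOS behaviour added above: once a sequence's length reaches its maximum length minus one, every logit except the EOS token(s) is driven to -inf, so the next decoded token must be EOS.

    import torch

    logits = torch.zeros(2, 6)       # batch of 2, toy vocabulary of 6 tokens
    sequence_lengths = [3, 9]        # the second sequence is one step away from its cap
    max_lengths = [16, 10]
    eos_token_id = 5

    out = apply_forced_eos_token_id(logits.clone(), sequence_lengths, max_lengths, eos_token_id)
    assert torch.argmax(out[1]).item() == eos_token_id  # forced to EOS
    assert torch.equal(out[0], torch.zeros(6))           # unaffected sequence stays untouched
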
67 changes: 35 additions & 32 deletions colossalai/inference/sampler.py
@@ -1,13 +1,12 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
from transformers.generation import GenerationConfig

from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.logit_processors import get_logits_processor


def greedy_sample(
generation_config,
logprobs: torch.Tensor,
) -> torch.Tensor:
"""
@@ -18,7 +17,6 @@ def greedy_sample(


def multinomial_sample(
generation_config,
probs: torch.Tensor,
) -> torch.Tensor:
"""
@@ -29,7 +27,7 @@ def multinomial_sample(


def beam_search_sample(
generation_config,
beam_width: int,
logprobs: torch.Tensor,
is_prompt: bool = False,
) -> List[Tuple[List[int], List[int]]]:
@@ -46,7 +44,6 @@ def beam_search_sample(
# NOTE: this beam search sample function is wrong now.
"""

beam_width = generation_config.num_beams
results = []
if is_prompt:
# Prompt phase.
@@ -64,20 +61,8 @@
return results


def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False):
if generation_config.num_beams == 1:
if generation_config.do_sample:
sample_tokens = multinomial_sample(generation_config, probs)
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt)

return sample_tokens


def search_tokens(
generation_config: GenerationConfig,
generation_config: Union[GenerationConfig, dict],
logits,
is_prompt: bool = False,
batch_token_ids: Optional[List[List[int]]] = None,
@@ -86,23 +71,41 @@
Sample tokens for finished requests.
"""
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type], batch_token_ids)

# do logit processor
if generation_config.do_sample:
# process temperature, top_k, top_p
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])

# convert GenerationConfig to dict
# temporary fix for compatibility with the usage of RPCInferenceEngine
if isinstance(generation_config, GenerationConfig):
generation_config = generation_config.to_dict()

if (repetition_penalty := generation_config.get("repetition_penalty", 1.0)) != 1.0:
logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids)
if (no_repeat_ngram_size := generation_config.get("no_repeat_ngram_size", 0)) > 0:
logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids)
if (forced_eos_token_id := generation_config.get("forced_eos_token_id", None)) is not None:
sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))]
max_out_lengths = [generation_config.max_length for _ in range(len(batch_token_ids))]
logits = get_logits_processor(
"forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id
)

if generation_config.get("do_sample"):
if (temperature := generation_config.get("temperature", 1.0)) != 1.0:
logits = get_logits_processor("temperature", logits, temperature)
if (top_k := generation_config.get("top_k", 0)) != 0:
logits = get_logits_processor("top_k", logits, top_k)
if (top_p := generation_config.get("top_p", 1.0)) < 1.0:
logits = get_logits_processor("top_p", logits, top_p)

# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# sample the next tokens
sample_tokens = _sample(probs, logprobs, generation_config, is_prompt)
if generation_config.get("num_beams", 1) != 1:
raise NotImplementedError("Beam search is not supported yet.")
if generation_config.get("do_sample", False):
sample_tokens = multinomial_sample(probs)
else:
sample_tokens = greedy_sample(logprobs)

return sample_tokens
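
A hedged usage sketch of the reworked search_tokens (not part of this diff; the tensor shapes and config values are illustrative): it now accepts either a GenerationConfig or the plain dict produced by GenerationConfig.to_dict(), applies only the processors whose settings differ from their no-op defaults, and falls back to greedy argmax when do_sample is false.

    import torch

    logits = torch.randn(2, 32000)             # (batch_size, vocab_size)
    batch_token_ids = [[1, 5, 7], [1, 9]]      # tokens generated so far, per sequence
    gen_cfg = {"do_sample": False, "repetition_penalty": 1.2, "no_repeat_ngram_size": 2}
    # next_tokens = search_tokens(gen_cfg, logits, is_prompt=False, batch_token_ids=batch_token_ids)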