9 changes: 9 additions & 0 deletions colossalai/inference/batch_bucket.py
@@ -102,6 +102,13 @@ def use_spec_dec(self) -> bool:
def num_tokens_to_verify(self) -> int:
return self._num_tokens_to_verify

@property
def batch_token_ids(self) -> List[List[int]]:
out = []
for seq in self.seqs_li:
out.append(seq.input_token_id + seq.output_token_id)
return out

def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,
@@ -328,6 +335,7 @@ def pop_n_seqs(
seqs.append(seq)
if not self.is_compact:
self._make_compact()

return seqs, block_tables

def pop_finished(
@@ -432,6 +440,7 @@ def merge(self, other: "BatchBucket") -> List[int]:
block_tables = torch.stack(block_tables_li)
self.add_seqs(seqs, alloc_block_tables=block_tables)
unmerged_ids = other.seqs_ids

return unmerged_ids

########## The following methods are expected to be used in modeling ###########
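For orientation (not part of the diff): the new batch_token_ids property simply concatenates each sequence's prompt tokens with its generated tokens, which is what the batch-aware logit processors introduced below consume. A minimal, self-contained sketch of the expected behavior, using hypothetical stand-ins for the real Sequence and BatchBucket classes:

# Hypothetical sketch only: _FakeSeq/_FakeBucket stand in for the real
# Sequence and BatchBucket classes to show what batch_token_ids returns.
from typing import List

class _FakeSeq:
    def __init__(self, input_token_id: List[int], output_token_id: List[int]):
        self.input_token_id = input_token_id      # prompt tokens
        self.output_token_id = output_token_id    # generated tokens

class _FakeBucket:
    def __init__(self, seqs: List[_FakeSeq]):
        self.seqs_li = seqs

    @property
    def batch_token_ids(self) -> List[List[int]]:
        # same logic as the property added in this PR
        return [seq.input_token_id + seq.output_token_id for seq in self.seqs_li]

bucket = _FakeBucket([_FakeSeq([1, 5, 7], [9, 2]), _FakeSeq([3, 3], [8])])
assert bucket.batch_token_ids == [[1, 5, 7, 9, 2], [3, 3, 8]]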
10 changes: 7 additions & 3 deletions colossalai/inference/config.py
@@ -99,7 +99,9 @@ class InferenceConfig:
early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
temperature (Optional[float]): The value used to modulate the randomness of the next-token probabilities, defaults to 1.0.
repetition_penalty (Optional[float]): The parameter that influences how the model treats tokens already present in the prompt and the generated text. Values greater than 1 encourage the model to introduce new tokens, whereas values less than 1 encourage token repetition, defaults to 1.0.
no_repeat_ngram_size (Optional[int]): If set to a value greater than 0, no n-gram of that size can appear more than once in the generated text, defaults to 0.
n_spec_tokens (int): The maximum number of speculative tokens, defaults to None.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
@@ -136,7 +138,9 @@ class InferenceConfig:
early_stopping: Optional[bool] = False
top_k: Optional[int] = None
top_p: Optional[float] = None
min_p: Optional[float] = None
temperature: Optional[float] = 1.0
no_repeat_ngram_size: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0

# speculative decoding configs
max_n_spec_tokens: int = 5
@@ -213,7 +217,7 @@ def to_generation_config(self, model_config) -> GenerationConfig:
"do_sample": self.do_sample,
"num_beams": self.beam_width,
}
for type in ["top_k", "top_p", "min_p"]:
for type in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]:
if hasattr(self, type):
meta_config[type] = getattr(self, type)
for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:
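A rough usage sketch of the sampling fields added in this file; the values are illustrative, and any InferenceConfig fields not shown are assumed to keep their defaults:

# Hypothetical configuration sketch; only the fields touched by this diff are set.
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    do_sample=True,
    temperature=0.8,          # soften the next-token distribution
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,   # >1 penalizes tokens that already appeared
    no_repeat_ngram_size=3,   # forbid any 3-gram from occurring twice
)
# to_generation_config(...) will now also forward repetition_penalty,
# no_repeat_ngram_size and temperature into the resulting GenerationConfig.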
6 changes: 3 additions & 3 deletions colossalai/inference/core/engine.py
@@ -424,7 +424,7 @@ def steps_spec_dec(self) -> List[Sequence]:

# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +472,7 @@ def steps_spec_dec(self) -> List[Sequence]:
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)

# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -738,7 +738,7 @@ def step(self) -> List[str]:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()

15 changes: 9 additions & 6 deletions colossalai/inference/core/request_handler.py
@@ -11,12 +11,9 @@
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger

__all__ = ["RunningList", "RequestHandler"]

logger = get_dist_logger(__name__)


class RunningList:
"""
@@ -331,15 +328,21 @@ def check_unfinished_seqs(self) -> bool:
def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

def search_tokens(self, generation_config: GenerationConfig, logits):
def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
"""
Sample the next tokens for the requests in the batch.
"""

# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type], cur_batch)

# do logit processor
if generation_config.do_sample:
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process temperature, top_k, top_p
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
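Putting the request_handler change together: search_tokens now runs the batch-aware processors (which need the per-sequence token histories from the BatchBucket) before the sampling-only processors. A condensed sketch of that flow, assuming logit_processor and the config objects are in scope; error handling and the final sampling step are omitted:

# Hypothetical condensed view of the new search_tokens control flow.
def _search_tokens_sketch(generation_config, logits, cur_batch):
    config_dict = generation_config.to_dict()

    # batch-aware processors: need cur_batch to read prompt + generated tokens
    for name in ["repetition_penalty", "no_repeat_ngram_size"]:
        if config_dict.get(name) is not None:
            logits = logit_processor(name, logits, config_dict[name], cur_batch)

    # per-logit processors: only applied when sampling is enabled
    if generation_config.do_sample:
        for name in ["temperature", "top_k", "top_p"]:
            if config_dict.get(name) is not None:
                logits = logit_processor(name, logits, config_dict[name])

    return logits  # the caller then greedily picks or samples the next tokens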
72 changes: 66 additions & 6 deletions colossalai/inference/logit_processors.py
@@ -1,6 +1,10 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py

import torch
import torch.nn.functional as F

from colossalai.inference.batch_bucket import BatchBucket

_LOGIT_PROCESSOR_MAP = {}


@@ -17,6 +21,66 @@ def register(func):
return register


@register_logit_processor("no_repeat_ngram_size")
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
"""
Enforces that no n-gram of the given size appears more than once, to avoid repeated word sequences.
"""

if not isinstance(ngram_size, int) or ngram_size < 0:
raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.")

if ngram_size != 0:
batch_token_ids = batch.batch_token_ids
batch_size = len(batch_token_ids)

for batch_id in range(batch_size):
current_token_ids = batch_token_ids[batch_id]
current_len = len(current_token_ids)
if current_len + 1 < ngram_size:
continue

ngrams_dict = {}

for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]]

prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len])
banned_token = ngrams_dict.get(prev_ngrams, [])

logits[batch_id, banned_token] = -float("inf")

return logits


@register_logit_processor("repetition_penalty")
def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
"""
Apply the repetition penalty to tokens that already appear in the prompt and the generated text.
"""

if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.")

logit_list = []

# TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
if penalty != 1.0:
batch_token_ids = batch.batch_token_ids
for batch_id in range(len(batch_token_ids)):
current_logit = logits[batch_id]
current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device)

current_score = torch.gather(current_logit, 0, current_token)
current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
logit_list.append(current_logit.scatter(0, current_token, current_score))

logits = torch.stack(logit_list)

return logits


@register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float):
"""
@@ -68,14 +132,13 @@ def top_p_logit_processor(logits, top_p: float):
return logits


def logit_processor(processor: str, logits, attrs):
def logit_processor(processor: str, logits, *args, **kwargs):
"""
Run the specified logit processor on the given logits.

Args:
processor(str): the type of logit processor
logits(torch.Tensor): input logits
attrs(dict): attrs of the logit processor

Returns:
logits after process
@@ -84,8 +147,5 @@ def logit_processor(processor: str, logits, attrs):
return logits
else:
func = _LOGIT_PROCESSOR_MAP[processor]
try:
logits = func(logits, attrs)
except Exception:
return logits
logits = func(logits, *args, **kwargs)
return logits
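
As a standalone sanity check of the repetition-penalty arithmetic used above (the token ids and logit values are made up):

# Toy illustration of the penalty update: scores of already-seen tokens are
# multiplied by the penalty when negative and divided by it when positive,
# so a penalty > 1 always lowers them.
import torch

logits = torch.tensor([2.0, -1.0, 0.5, 3.0])  # single sequence, vocab of 4
seen_tokens = torch.tensor([0, 1])            # token ids already generated
penalty = 1.2

scores = torch.gather(logits, 0, seen_tokens)
scores = torch.where(scores < 0, scores * penalty, scores / penalty)
penalized = logits.scatter(0, seen_tokens, scores)

print(penalized)  # tensor([ 1.6667, -1.2000,  0.5000,  3.0000])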