
Commit 2d94b8c

Adapt repetition_penalty and no_repeat_ngram_size
1 parent 12e7c28 commit 2d94b8c

File tree

4 files changed: +136 -12 lines

colossalai/inference/batch_bucket.py

Lines changed: 27 additions & 0 deletions

@@ -54,6 +54,10 @@ def __init__(
         self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
         self._block_tables_helper = torch.full_like(self._block_tables, -1)
 
+        # 'batch_updated' is used as a flag to indicate whether there are additions or deletions of sequences in the current batch.
+        self.batch_updated = True
+        self._batch_prompt_ids = None
+
     @property
     def is_empty(self):
         return self._current_batch_size == 0

@@ -99,6 +103,13 @@ def use_spec_dec(self) -> bool:
     def num_tokens_to_verify(self) -> int:
         return self._num_tokens_to_verify
 
+    @property
+    def batch_token_ids(self):
+        if self.batch_updated:
+            self._batch_prompt_ids = self.get_batch_token_ids()
+            self.batch_updated = False
+        return self._batch_prompt_ids
+
     def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
         """Set batch bucket to use speculatvie decoding.
         This will notify the adjust the lengths of inputs during modeling,

@@ -167,6 +178,7 @@ def add_seq(
         elif alloc_block_table_fn:
             alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item())
         self._current_batch_size += 1
+        self.batch_updated = True
         return block_table
 
     def add_seqs(

@@ -218,6 +230,7 @@ def add_seqs(
 
         self._current_batch_size += num_seqs_to_add
         seqs[:] = seqs[num_seqs_to_add:]
+        self.batch_updated = True
 
         return block_tables
 

@@ -271,6 +284,7 @@ def pop_seq_update_batch(
             self._block_tables[0].fill_(-1)
         self._sequences_indexes.pop(request_id)
         self._current_batch_size -= 1
+        self.batch_updated = True
 
         return seq, block_table
 

@@ -325,6 +339,9 @@ def pop_n_seqs(
             seqs.append(seq)
         if not self.is_compact:
             self._make_compact()
+
+        self.batch_updated = True
+
         return seqs, block_tables
 
     def pop_finished(

@@ -429,6 +446,8 @@ def merge(self, other: "BatchBucket") -> List[int]:
             block_tables = torch.stack(block_tables_li)
             self.add_seqs(seqs, alloc_block_tables=block_tables)
         unmerged_ids = other.seqs_ids
+        self.batch_updated = True
+
         return unmerged_ids
 
     ########## The following methods are expected to be used in modeling ###########

@@ -501,6 +520,14 @@ def get_sequence_lengths(self) -> torch.Tensor:
         sequence_lengths = self.seq_lengths[: self.current_batch_size]
         return sequence_lengths.to(device=self.device)
 
+    def get_batch_token_ids(self) -> List[torch.LongTensor]:
+        assert self.is_compact  # Debug usage
+        out = []
+        for seq_id, _ in self._sequences_indexes.items():
+            seq: Sequence = self._sequences_dict[seq_id]
+            out.append(torch.tensor(seq.input_token_id + seq.output_token_id, device=self.device))
+        return out
+
     # For compatibility
     @property
     def fd_inter_tensor(self) -> None:
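
The new batch_token_ids property is effectively a cache that is rebuilt only when batch membership changes: every mutation path (add_seq, add_seqs, pop_seq_update_batch, pop_n_seqs, merge) sets batch_updated = True, so repeated reads during decoding reuse the list built by get_batch_token_ids. Below is a minimal standalone sketch of the same invalidate-on-mutation pattern; the ToyBucket class and its fields are illustrative only, not the real BatchBucket API.

import torch
from typing import Dict, List


class ToyBucket:
    """Illustrative only: mirrors the batch_updated / batch_token_ids caching idea."""

    def __init__(self):
        self._seqs: Dict[int, List[int]] = {}  # request_id -> prompt + generated token ids
        self.batch_updated = True               # force a rebuild on first access
        self._batch_prompt_ids = None

    def add_seq(self, request_id: int, token_ids: List[int]) -> None:
        self._seqs[request_id] = token_ids
        self.batch_updated = True               # invalidate the cached list

    def get_batch_token_ids(self) -> List[torch.Tensor]:
        return [torch.tensor(ids) for ids in self._seqs.values()]

    @property
    def batch_token_ids(self) -> List[torch.Tensor]:
        if self.batch_updated:                  # rebuild only when the batch changed
            self._batch_prompt_ids = self.get_batch_token_ids()
            self.batch_updated = False
        return self._batch_prompt_ids


bucket = ToyBucket()
bucket.add_seq(0, [1, 2, 3])
first = bucket.batch_token_ids   # rebuilt here
second = bucket.batch_token_ids  # cached, no rebuild
assert first is second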

colossalai/inference/config.py

Lines changed: 7 additions & 3 deletions

@@ -99,7 +99,9 @@ class InferenceConfig:
         early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
         top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
         top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
-        min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
+        temperature (Optional[float]): The value used to control the randomness of sampling; lower values make generation more deterministic, defaults to 1.0.
+        repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition, defaults to 1.0.
+        no_repeat_ngram_size (Optional[int]): If set to a value greater than 0, any n-gram of that size can occur only once in the generated text, defaults to 0.
         n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
         glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
         block_size (int): The number of blocks in a logical block, defaults to 16.

@@ -134,7 +136,9 @@ class InferenceConfig:
     early_stopping: Optional[bool] = False
     top_k: Optional[int] = None
     top_p: Optional[float] = None
-    min_p: Optional[float] = None
+    temperature: Optional[float] = 1.0
+    no_repeat_ngram_size: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
 
     # speculative decoding configs
     max_n_spec_tokens: int = 5

@@ -204,7 +208,7 @@ def to_generation_config(self, model_config) -> GenerationConfig:
             "do_sample": self.do_sample,
             "num_beams": self.beam_width,
         }
-        for type in ["top_k", "top_p", "min_p"]:
+        for type in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]:
            if hasattr(self, type):
                meta_config[type] = getattr(self, type)
        for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:

colossalai/inference/core/request_handler.py

Lines changed: 15 additions & 3 deletions

@@ -328,12 +328,24 @@ def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.
         """
-        # do logit processor
+
         # NOTE: need to decide the granularity to process logits (sequence or batch)
         config_dict = generation_config.to_dict()
-        for type in ["top_k", "top_p", "min_p"]:
+        # process repetition_penalty and no_repeat_ngram_size
+        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
             if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type])
+                if not self.prefill_bb.is_empty:
+                    batch = self.prefill_bb
+                else:
+                    batch = self.running_bb
+                logits = logit_processor(type, logits, config_dict[type], batch.batch_token_ids)
+
+        # do logit processor
+        if generation_config.do_sample:
+            # process temperature, top_k, top_p
+            for type in ["temperature", "top_k", "top_p"]:
+                if type in config_dict and config_dict[type] is not None:
+                    logits = logit_processor(type, logits, config_dict[type])
 
         # calculate probs
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
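
The processing order in search_tokens is now two passes: the history-aware penalties, which need every sequence's token ids (batch_token_ids), run first; the sampling warpers (temperature, top_k, top_p) run only when do_sample is on. A condensed sketch of that dispatch follows; the logits, batch_token_ids, and config_dict values are stand-ins for the handler's real state.

import torch
from colossalai.inference.logit_processors import logit_processor

# Stand-ins for the handler's state: 2 sequences, vocab size 8.
logits = torch.randn(2, 8)
batch_token_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
config_dict = {"repetition_penalty": 1.2, "no_repeat_ngram_size": 2,
               "temperature": 0.7, "top_k": 5, "top_p": 0.9}
do_sample = True

# Pass 1: history-aware penalties take the batch token ids as an extra argument.
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
    if config_dict.get(type) is not None:
        logits = logit_processor(type, logits, config_dict[type], batch_token_ids)

# Pass 2: sampling warpers, only when sampling is enabled.
if do_sample:
    for type in ["temperature", "top_k", "top_p"]:
        if config_dict.get(type) is not None:
            logits = logit_processor(type, logits, config_dict[type])

probs = torch.softmax(logits, dim=-1, dtype=torch.float)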

colossalai/inference/logit_processors.py

Lines changed: 87 additions & 6 deletions

@@ -1,7 +1,13 @@
+# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
+from typing import List
+
 import torch
 import torch.nn.functional as F
 
+from colossalai.logging import get_dist_logger
+
 _LOGIT_PROCESSOR_MAP = {}
+logger = get_dist_logger(__name__)
 
 
 def register_logit_processor(process_type):

@@ -17,6 +23,81 @@ def register(func):
     return register
 
 
+@register_logit_processor("no_repeat_ngram_size")
+def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[torch.LongTensor]):
+    """
+    Enforces no repetition of n-grams to avoid repetitions of word sequences.
+    """
+
+    if not isinstance(ngram_size, int) or ngram_size < 0:
+        raise ValueError(f"'ngram_size={ngram_size}' has to be a non-negative integer.")
+
+    if ngram_size != 0:
+        batch_size = len(batch_token_ids)
+
+        for batch_id in range(batch_size):
+            current_token_ids = batch_token_ids[batch_id]
+            current_len = current_token_ids.size(0)
+            if current_len + 1 < ngram_size:
+                continue
+
+            token_ids_list = current_token_ids.tolist()
+
+            ngrams_dict = {}
+
+            for ngram in zip(*[token_ids_list[i:] for i in range(ngram_size)]):
+                prev_ngram_tuple = tuple(ngram[:-1])
+                ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]]
+
+            prev_ngrams = tuple(token_ids_list[current_len + 1 - ngram_size : current_len])
+            banned_token = ngrams_dict.get(prev_ngrams, [])
+
+            logits[batch_id, banned_token] = -float("inf")
+
+    return logits
+
+
+@register_logit_processor("repetition_penalty")
+def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[torch.LongTensor]):
+    """
+    Apply the penalty to tokens that already appear in the prompt and the generated text.
+    """
+
+    if not isinstance(penalty, float) or not (penalty > 0):
+        raise ValueError(f"'penalty={penalty}' has to be a strictly positive float.")
+
+    logit_list = []
+
+    # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
+    if penalty != 1.0:
+        for batch_id in range(len(batch_token_ids)):
+            current_logit = logits[batch_id]
+            current_token = batch_token_ids[batch_id]
+
+            current_score = torch.gather(current_logit, 0, current_token)
+            current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
+            logit_list.append(current_logit.scatter(0, current_token, current_score))
+
+        logits = torch.stack(logit_list)
+
+    return logits
+
+
+@register_logit_processor("temperature")
+def temperature_logit_process(logits, temperature: float):
+    """
+    Apply temperature scaling.
+    """
+
+    if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0):
+        except_msg = f"'temperature={temperature}' should be a strictly positive float, less than or equal to 1.0 and greater than 0."
+        if temperature == 0.0:
+            except_msg += " If you want to use greedy decoding strategies, set `do_sample=False`."
+        raise ValueError(except_msg)
+
+    return logits if temperature == 1.0 else logits / temperature
+
+
 @register_logit_processor("top_k")
 def top_k_logit_processor(logits, top_k: int):
     """

@@ -45,14 +126,13 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, attrs):
+def logit_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.
 
     Args:
         processor(str): the type of logit processor
         logits(torch.Tensor): input logits
-        attrs(dict): attrs of the logit processor
 
     Returns:
         logits after process

@@ -61,8 +141,9 @@ def logit_processor(processor: str, logits, attrs):
         return logits
     else:
         func = _LOGIT_PROCESSOR_MAP[processor]
-        try:
-            logits = func(logits, attrs)
-        except Exception:
-            return logits
+        # try:
+        logits = func(logits, *args, **kwargs)
+        # except Exception as e:
+        #     logger.warning(f"An exception ({e}) occurred during the logit processing ({processor}), skip this logit processing step.")
+        #     return logits
         return logits
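
As a quick sanity check (a toy sketch, not part of the commit), the registry dispatch can be exercised directly. logit_processor now forwards arbitrary positional arguments, so the history-aware processors receive the per-sequence token ids while the older single-attribute processors keep their previous call shape.

import torch
from colossalai.inference.logit_processors import logit_processor

logits = torch.zeros(2, 6)  # 2 sequences, vocab of 6 tokens
batch_token_ids = [
    torch.tensor([1, 2, 3, 1, 2]),  # "1 2" was already followed by 3, so 3 gets banned next
    torch.tensor([4, 5]),           # no complete 3-gram yet, so nothing is banned
]

out = logit_processor("no_repeat_ngram_size", logits, 3, batch_token_ids)
assert out[0, 3] == -float("inf") and torch.isfinite(out[1]).all()

out = logit_processor("repetition_penalty", out, 1.2, batch_token_ids)  # extra positional arg
out = logit_processor("temperature", out, 0.7)                          # single attribute, as before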
