Commit 5d66594

refactor and add
1 parent 8bcfe36 commit 5d66594

File tree

colossalai/inference/config.py
colossalai/inference/core/request_handler.py
colossalai/inference/logit_processors.py

3 files changed: 91 additions & 18 deletions

colossalai/inference/config.py

Lines changed: 2 additions & 0 deletions

@@ -207,6 +207,8 @@ class InferenceConfig(RPC_PARAM):
     temperature: Optional[float] = 1.0
     no_repeat_ngram_size: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
+    forced_bos_token_id: int = None
+    forced_eos_token_id: int = None

     # speculative decoding configs
     max_n_spec_tokens: int = 5

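The two new fields sit alongside the existing sampling options in InferenceConfig. A minimal sketch of how they might be supplied, assuming the config is constructed with keyword arguments as a dataclass and using only fields visible in this diff (values are illustrative):

from colossalai.inference.config import InferenceConfig

config = InferenceConfig(
    temperature=1.0,
    no_repeat_ngram_size=2,
    repetition_penalty=1.2,
    # Force this token id once a sequence reaches its maximum output length.
    forced_eos_token_id=2,
    # forced_bos_token_id targets encoder-decoder models and, per the note in
    # logit_processors.py below, is not exercised by the current pipeline.
)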
colossalai/inference/core/request_handler.py

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
+from colossalai.inference.logit_processors import logits_processor
+from colossalai.inference.sampler import *
 from colossalai.inference.struct import RequestStatus, Sequence
 from colossalai.logging import get_dist_logger

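Only imports change here; the dispatch itself lives in logit_processors.py. As a rough, self-contained illustration (not the handler's actual code path), applying the new forced-EOS processor to a batch of logits before sampling could look like this, with placeholder tensors standing in for the handler's batch state:

import torch

from colossalai.inference.logit_processors import logits_processor

logits = torch.randn(2, 8)                # toy batch: 2 sequences, vocab size 8
sequence_lengths = torch.tensor([5, 9])   # placeholder per-sequence lengths
max_out_lengths = torch.tensor([16, 10])  # placeholder per-sequence output budgets

# Sequence 1 is one step from its budget, so its next token is forced to eos_token_id=2.
logits = logits_processor("forced_eos_token_id", logits, sequence_lengths, max_out_lengths, 2)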
colossalai/inference/logit_processors.py

Lines changed: 87 additions & 18 deletions

@@ -1,26 +1,27 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
-from typing import List
+import logging
+from typing import List, Union

 import torch
 import torch.nn.functional as F

-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}


-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register flops computation function for operation.
     """

     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func

     return register


-@register_logit_processor("no_repeat_ngram_size")
+@register_logits_processor("no_repeat_ngram_size")
 def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
@@ -52,7 +53,7 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
     return logits


-@register_logit_processor("repetition_penalty")
+@register_logits_processor("repetition_penalty")
 def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
@@ -78,8 +79,8 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
     return logits


-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def temperature_logits_process(logits, temperature: float):
     """
     apply temperature scaling.
     """
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature


-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def top_k_logits_processor(logits, top_k: int):
     """
     top_k logit processor
     """
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits


-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def top_p_logits_processor(logits, top_p: float):
     """
     top_p logit processor
     """
@@ -129,7 +130,74 @@ def top_p_logit_processor(logits, top_p: float):
     return logits


-def logit_processor(processor: str, logits, *args, **kwargs):
+@register_logits_processor("forced_bos_token_id")
+def forced_bos_token_processor(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_out_lengths: Union[torch.Tensor, List[int]],
+    bos_token_id: int,
+):
+    # NOTE For now, optimizations for encoder-decoder models have not been supported yet
+    # And this function will never be called in the current implementation.
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_out_lengths, torch.Tensor):
+        max_out_lengths = max_out_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_out_lengths = max_out_lengths[:num_sequences]
+    for i, sequence_length in enumerate(sequence_lengths):
+        if sequence_length == 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, bos_token_id] = 0
+
+    return logits
+
+
+@register_logits_processor("forced_eos_token_id")
+def forced_eos_token_processor(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_out_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths
+        max_out_lengths(torch.Tensor): maximum output lengths for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_out_lengths, torch.Tensor):
+        max_out_lengths = max_out_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_out_lengths = max_out_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_out_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def logits_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.

@@ -140,9 +208,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
+        func = _LOGITS_PROCESSOR_MAP[processor]
         logits = func(logits, *args, **kwargs)
-    return logits
+
+    return logits
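The rename makes the registry naming consistent (_LOGITS_PROCESSOR_MAP, register_logits_processor, logits_processor), and unknown processor names now log a warning instead of being ignored silently. A short usage sketch of the dispatcher, with illustrative values:

import torch

from colossalai.inference.logit_processors import logits_processor

logits = torch.randn(2, 8)  # toy batch: 2 sequences, vocab size 8

# Registered processors are looked up by name; extra arguments are forwarded to them.
logits = logits_processor("temperature", logits, 0.7)
logits = logits_processor("top_k", logits, 4)

# An unregistered name logs a warning and returns the logits unchanged.
logits = logits_processor("typical_p", logits, 0.9)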

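Both new processors rely on the same masking trick: set the whole row to -inf, then set the forced token's logit to 0, so softmax places all probability mass on that token. A minimal demonstration with an assumed EOS token id of 2:

import torch
import torch.nn.functional as F

row = torch.full((1, 8), -float("inf"))
row[0, 2] = 0  # assumed EOS token id
print(F.softmax(row, dim=-1))  # tensor([[0., 0., 1., 0., 0., 0., 0., 0.]])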