 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py
-from typing import List
+import logging
+from typing import List, Union

 import torch
 import torch.nn.functional as F

-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}


-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
1112 """
1213 register flops computation function for operation.
1314 """

     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func

     return register


-@register_logit_processor("no_repeat_ngram_size")
+@register_logits_processor("no_repeat_ngram_size")
 def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]):
     """
     enforces no repetition of n-grams to avoid repetitions of word sequences.
@@ -52,7 +53,7 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids:
     return logits


-@register_logit_processor("repetition_penalty")
+@register_logits_processor("repetition_penalty")
 def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]):
     """
     apply the penalty to the tokens present in the prompt.
@@ -78,8 +79,8 @@ def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: Li
     return logits


-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def temperature_logits_process(logits, temperature: float):
     """
     apply temperature scaling.
     """
@@ -93,8 +94,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature


-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def top_k_logits_processor(logits, top_k: int):
     """
     top_k logit processor
     """
@@ -107,8 +108,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits


-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def top_p_logits_processor(logits, top_p: float):
     """
     top_p logit processor
     """
@@ -129,7 +130,74 @@ def top_p_logit_processor(logits, top_p: float):
     return logits


-def logit_processor(processor: str, logits, *args, **kwargs):
+@register_logits_processor("forced_bos_token_id")
+def forced_bos_token_processor(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_out_lengths: Union[torch.Tensor, List[int]],
+    bos_token_id: int,
+):
+    # NOTE: Optimizations for encoder-decoder models are not yet supported,
+    # so this function is never called in the current implementation.
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_out_lengths, torch.Tensor):
+        max_out_lengths = max_out_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_out_lengths = max_out_lengths[:num_sequences]
+    for i, sequence_length in enumerate(sequence_lengths):
+        if sequence_length == 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, bos_token_id] = 0
+
+    return logits
+
+
+@register_logits_processor("forced_eos_token_id")
+def forced_eos_token_processor(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_out_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths
+        max_out_lengths(torch.Tensor): maximum output lengths for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_out_lengths, torch.Tensor):
+        max_out_lengths = max_out_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_out_lengths = max_out_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_out_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def logits_processor(processor: str, logits, *args, **kwargs):
     """
     do logit process for given logits.

@@ -140,9 +208,10 @@ def logit_processor(processor: str, logits, *args, **kwargs):
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
+        func = _LOGITS_PROCESSOR_MAP[processor]
         logits = func(logits, *args, **kwargs)
-    return logits
+
+    return logits
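
A minimal usage sketch of how the logits_processor dispatcher chains the registered processors during sampling. This is not part of the commit, and the import path colossalai.inference.logit_processors is an assumption that may differ from the actual module location.

import torch

# Module path is assumed; adjust to wherever this file lives in the repository.
from colossalai.inference.logit_processors import logits_processor

logits = torch.randn(4, 32000)                         # (batch_size, vocab_size)
logits = logits_processor("temperature", logits, 0.7)  # scale logits by 1 / temperature
logits = logits_processor("top_k", logits, 50)         # mask everything outside the top 50 tokens
logits = logits_processor("top_p", logits, 0.9)        # nucleus filtering at p = 0.9
probs = torch.softmax(logits, dim=-1)                  # ready for multinomial sampling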
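
A hand-made check of the new forced_eos_token_id processor, with invented batch values: the row whose sequence length has reached max_out_length - 1 gets all of its probability mass moved onto the eos token, while the other row is left untouched. The import path is again assumed.

import torch
import torch.nn.functional as F

from colossalai.inference.logit_processors import logits_processor  # assumed path

logits = torch.randn(2, 8)
sequence_lengths = [5, 3]    # current length of each sequence
max_out_lengths = [6, 10]    # maximum output length of each sequence
eos_token_id = 2

out = logits_processor("forced_eos_token_id", logits, sequence_lengths, max_out_lengths, eos_token_id)
probs = F.softmax(out, dim=-1)

assert torch.argmax(probs[0]).item() == eos_token_id  # sequence 0 is forced to emit eos
assert probs[0, eos_token_id].item() == 1.0           # all probability mass sits on the eos token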