+import logging
+from typing import List, Union
+
 import torch
 import torch.nn.functional as F
 
-_LOGIT_PROCESSOR_MAP = {}
+_LOGITS_PROCESSOR_MAP = {}
 
 
-def register_logit_processor(process_type):
+def register_logits_processor(process_type):
     """
     register a logits processor function for the given process type.
     """
 
     def register(func):
-        global _LOGIT_PROCESSOR_MAP
-        _LOGIT_PROCESSOR_MAP[process_type] = func
+        global _LOGITS_PROCESSOR_MAP
+        _LOGITS_PROCESSOR_MAP[process_type] = func
         return func
 
     return register
 
 
-@register_logit_processor("temperature")
-def temperature_logit_process(logits, temperature: float):
+@register_logits_processor("temperature")
+def temperature_logits_process(logits, temperature: float):
     """
     apply temperature scaling.
     """
@@ -32,8 +35,8 @@ def temperature_logit_process(logits, temperature: float):
     return logits if temperature == 1.0 else logits / temperature
 
 
-@register_logit_processor("top_k")
-def top_k_logit_processor(logits, top_k: int):
+@register_logits_processor("top_k")
+def top_k_logits_processor(logits, top_k: int):
     """
     top_k logit processor
     """
@@ -46,8 +49,8 @@ def top_k_logit_processor(logits, top_k: int):
     return logits
 
 
-@register_logit_processor("top_p")
-def top_p_logit_processor(logits, top_p: float):
+@register_logits_processor("top_p")
+def top_p_logits_processor(logits, top_p: float):
     """
     top_p logit processor
     """
@@ -68,24 +71,88 @@ def top_p_logit_processor(logits, top_p: float):
     return logits
 
 
-def logit_processor(processor: str, logits, attrs):
+@register_logits_processor("forced_bos_token_id")
+def forced_bos_token_processor(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_out_lengths: Union[torch.Tensor, List[int]],
+    bos_token_id: int,
+):
+    # NOTE: Optimizations for encoder-decoder models are not supported yet,
+    # so this function is never called in the current implementation.
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_out_lengths, torch.Tensor):
+        max_out_lengths = max_out_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_out_lengths = max_out_lengths[:num_sequences]
+    for i, sequence_length in enumerate(sequence_lengths):
+        if sequence_length == 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, bos_token_id] = 0
+
+    return logits
+
+
+@register_logits_processor("forced_eos_token_id")
+def forced_eos_token_processor(
+    logits: torch.Tensor,
+    sequence_lengths: Union[torch.Tensor, List[int]],
+    max_out_lengths: Union[torch.Tensor, List[int]],
+    eos_token_id: Union[int, List[int]],
+):
+    """
+    Enforces the specified token as the last generated token when the maximum output length
+    is reached. Notice that the maximum output lengths for different sequences, even if they're
+    in the same batch, can be different.
+
+    Args:
+        logits(torch.Tensor): logits
+        sequence_lengths(torch.Tensor): sequence lengths
+        max_out_lengths(torch.Tensor): maximum output lengths for each sequence
+        eos_token_id(Union[int, List[int]]): forced eos token id
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.tolist()
+    if isinstance(max_out_lengths, torch.Tensor):
+        max_out_lengths = max_out_lengths.tolist()
+
+    select_indexes = []
+    num_sequences = logits.shape[0]
+    sequence_lengths = sequence_lengths[:num_sequences]
+    max_out_lengths = max_out_lengths[:num_sequences]
+    for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_out_lengths)):
+        if sequence_length == max_out_length - 1:
+            select_indexes.append(i)
+    if select_indexes:
+        logits[select_indexes, :] = -float("inf")
+        logits[select_indexes, eos_token_id] = 0
+
+    return logits
+
+
+def logits_processor(processor: str, logits, *args):
     """
     apply the registered logits processor to the given logits.
 
     Args:
         processor(str): the type of logits processor
         logits(torch.Tensor): input logits
-        attrs(dict): attrs of the logit processor
 
     Returns:
         logits after process
     """
-    if processor not in _LOGIT_PROCESSOR_MAP:
-        return logits
+    if processor not in _LOGITS_PROCESSOR_MAP:
+        logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.")
     else:
-        func = _LOGIT_PROCESSOR_MAP[processor]
-        try:
-            logits = func(logits, attrs)
-        except Exception:
-            return logits
-    return logits
+        func = _LOGITS_PROCESSOR_MAP[processor]
+        logits = func(logits, *args)
+
+    return logits
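
For reference, a minimal usage sketch of the renamed dispatcher introduced by this change; the module path and tensor shapes are illustrative assumptions, not part of the diff:

import torch

from logits_process import logits_processor  # hypothetical import path

# a toy batch of logits: 2 sequences over a 6-token vocabulary
logits = torch.randn(2, 6)

# each call looks up the processor registered under the given key and passes
# any extra positional arguments straight through via *args
logits = logits_processor("temperature", logits, 0.7)
logits = logits_processor("top_k", logits, 3)
logits = logits_processor("top_p", logits, 0.9)

# an unknown key now logs a warning and returns the logits unchanged,
# instead of silently swallowing errors in a try/except as before
logits = logits_processor("unknown", logits)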
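
And a sketch of the new forced-EOS behavior, with made-up sequence lengths and a made-up EOS id:

import torch

from logits_process import logits_processor  # hypothetical import path

logits = torch.randn(2, 6)
sequence_lengths = torch.tensor([3, 7])  # tokens generated so far, per sequence
max_out_lengths = torch.tensor([10, 8])  # per-sequence output length limits
eos_token_id = 5

# sequence 1 has reached max_out_length - 1, so its row is masked to -inf
# everywhere except the EOS id, forcing EOS as its final token; sequence 0
# is left untouched
logits = logits_processor("forced_eos_token_id", logits, sequence_lengths, max_out_lengths, eos_token_id)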