@@ -134,20 +134,20 @@ def apply_top_p(logits, top_p: float):
134134def apply_forced_bos_token_id (
135135 logits : torch .Tensor ,
136136 sequence_lengths : Union [torch .Tensor , List [int ]],
137- max_out_lengths : Union [torch .Tensor , List [int ]],
137+ max_lengths : Union [torch .Tensor , List [int ]],
138138 bos_token_id : int ,
139139):
140140 # NOTE: Optimizations for encoder-decoder models are not supported yet,
141141 # so this function is never called in the current implementation.
142142 if isinstance (sequence_lengths , torch .Tensor ):
143143 sequence_lengths = sequence_lengths .tolist ()
144- if isinstance (max_out_lengths , torch .Tensor ):
145- max_out_lengths = max_out_lengths .tolist ()
144+ if isinstance (max_lengths , torch .Tensor ):
145+ max_lengths = max_lengths .tolist ()
146146
147147 select_indexes = []
148148 num_sequences = logits .shape [0 ]
149149 sequence_lengths = sequence_lengths [:num_sequences ]
150- max_out_lengths = max_out_lengths [:num_sequences ]
150+ max_lengths = max_lengths [:num_sequences ]
151151 for i , sequence_length in enumerate (sequence_lengths ):
152152 if sequence_length == 1 :
153153 select_indexes .append (i )
@@ -162,7 +162,7 @@ def apply_forced_bos_token_id(
162162def apply_forced_eos_token_id (
163163 logits : torch .Tensor ,
164164 sequence_lengths : Union [torch .Tensor , List [int ]],
165- max_out_lengths : Union [torch .Tensor , List [int ]],
165+ max_lengths : Union [torch .Tensor , List [int ]],
166166 eos_token_id : Union [int , List [int ]],
167167):
168168 """
@@ -172,22 +172,22 @@ def apply_forced_eos_token_id(
172172
173173 Args:
174174 logits(torch.Tensor): logits
175- sequence_lengths(torch.Tensor): sequence lengths
176- max_out_lengths (torch.Tensor): maximum output lengths for each sequence
175+ sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens
176+ max_lengths (torch.Tensor): the maximum length for each sequence
177177 eos_token_id(Union[int, List[int]]): forced eos token id
178178 """
179179 if isinstance (eos_token_id , int ):
180180 eos_token_id = [eos_token_id ]
181181 if isinstance (sequence_lengths , torch .Tensor ):
182182 sequence_lengths = sequence_lengths .tolist ()
183- if isinstance (max_out_lengths , torch .Tensor ):
184- max_out_lengths = max_out_lengths .tolist ()
183+ if isinstance (max_lengths , torch .Tensor ):
184+ max_lengths = max_lengths .tolist ()
185185
186186 select_indexes = []
187187 num_sequences = logits .shape [0 ]
188188 sequence_lengths = sequence_lengths [:num_sequences ]
189- max_out_lengths = max_out_lengths [:num_sequences ]
190- for i , (sequence_length , max_out_length ) in enumerate (zip (sequence_lengths , max_out_lengths )):
189+ max_lengths = max_lengths [:num_sequences ]
190+ for i , (sequence_length , max_out_length ) in enumerate (zip (sequence_lengths , max_lengths )):
191191 if sequence_length == max_out_length - 1 :
192192 select_indexes .append (i )
193193 if select_indexes :
0 commit comments