@@ -280,17 +280,30 @@ def _verify_greedy_sampling(self) -> None:
280
280
f"Got { self .best_of } ." )
281
281
282
282
def update_from_generation_config(
        self,
        generation_config: Dict[str, Any],
        model_eos_token_id: Optional[int] = None) -> None:
    """Update stop-token state from non-default generation_config values.

    Args:
        generation_config: Mapping that may carry an ``"eos_token_id"``
            entry, given either as a single int or as an iterable of ints.
        model_eos_token_id: Primary EOS token id of the model, if known.
            When given, it is always added to ``self.all_stop_token_ids``
            and is never added to ``self.stop_token_ids``.

    Side effects:
        Mutates ``self.all_stop_token_ids`` in place and may rebind
        ``self.stop_token_ids`` to a new list (element order is not
        preserved because the merge goes through a set).
    """
    if model_eos_token_id is not None:
        # Add the eos token id into the sampling_params to support
        # min_tokens processing.
        self.all_stop_token_ids.add(model_eos_token_id)

    # Update eos_token_id for generation. Compare against None explicitly
    # so that a token id of 0 is not mistaken for "not configured".
    eos_ids = generation_config.get("eos_token_id")
    if eos_ids is None:
        return

    # It can be either an int or a list of ints; normalise to a fresh set
    # so the caller's generation_config value is never mutated.
    eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)

    if model_eos_token_id is not None:
        # We don't need to include the primary eos_token_id in
        # stop_token_ids since it's handled separately for stopping
        # purposes.
        eos_ids.discard(model_eos_token_id)

    if not eos_ids:
        return

    # The extra EOS ids are always tracked in all_stop_token_ids, even
    # when ignore_eos is set.
    self.all_stop_token_ids.update(eos_ids)

    if not self.ignore_eos:
        # Merge with the existing stop tokens, dropping duplicates.
        eos_ids.update(self.stop_token_ids)
        self.stop_token_ids = list(eos_ids)
294
307
295
308
@cached_property
296
309
def sampling_type (self ) -> SamplingType :
0 commit comments