@@ -364,17 +364,30 @@ def _verify_greedy_sampling(self) -> None:
364364 f"Got { self .best_of } ." )
365365
366366 def update_from_generation_config (
367- self , generation_config : Dict [str , Any ]) -> None :
367+ self ,
368+ generation_config : Dict [str , Any ],
369+ model_eos_token_id : Optional [int ] = None ) -> None :
368370 """Update if there are non-default values from generation_config"""
371+
372+ if model_eos_token_id is not None :
373+ # Add the eos token id into the sampling_params to support
374+ # min_tokens processing.
375+ self .all_stop_token_ids .add (model_eos_token_id )
376+
369377 # Update eos_token_id for generation
370- if (not self .ignore_eos ) and (eos_ids :=
371- generation_config .get ("eos_token_id" )):
378+ if (eos_ids := generation_config .get ("eos_token_id" )) is not None :
372379 # it can be either int or list of int
373- if isinstance (eos_ids , int ):
374- eos_ids = [eos_ids ]
375- original_stop_token_ids = set (self .stop_token_ids )
376- original_stop_token_ids .update (eos_ids )
377- self .stop_token_ids = list (original_stop_token_ids )
380+ eos_ids = {eos_ids } if isinstance (eos_ids , int ) else set (eos_ids )
381+ if model_eos_token_id is not None :
382+ # We don't need to include the primary eos_token_id in
383+ # stop_token_ids since it's handled separately for stopping
384+ # purposes.
385+ eos_ids .discard (model_eos_token_id )
386+ if eos_ids :
387+ self .all_stop_token_ids .update (eos_ids )
388+ if not self .ignore_eos :
389+ eos_ids .update (self .stop_token_ids )
390+ self .stop_token_ids = list (eos_ids )
378391
379392 @cached_property
380393 def sampling_type (self ) -> SamplingType :
0 commit comments