Skip to content

Commit 4673a12

Browse files
njhilljimpang
authored andcommitted
[BugFix] Fix min_tokens behaviour for multiple eos tokens (vllm-project#5849)
1 parent 18bb15b commit 4673a12

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

vllm/engine/llm_engine.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -606,12 +606,9 @@ def _create_sequence_group_with_sampling(
606606
# Defensive copy of SamplingParams, which are used by the sampler,
607607
# this doesn't deep-copy LogitsProcessor objects
608608
sampling_params = sampling_params.clone()
609-
# Add the eos token id into the sampling_params to support min_tokens
610-
# processing
611-
if seq.eos_token_id is not None:
612-
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
609+
613610
sampling_params.update_from_generation_config(
614-
self.generation_config_fields)
611+
self.generation_config_fields, seq.eos_token_id)
615612

616613
# Create the sequence group.
617614
seq_group = SequenceGroup(

vllm/sampling_params.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,17 +280,30 @@ def _verify_greedy_sampling(self) -> None:
280280
f"Got {self.best_of}.")
281281

282282
def update_from_generation_config(
283-
self, generation_config: Dict[str, Any]) -> None:
283+
self,
284+
generation_config: Dict[str, Any],
285+
model_eos_token_id: Optional[int] = None) -> None:
284286
"""Update if there are non-default values from generation_config"""
287+
288+
if model_eos_token_id is not None:
289+
# Add the eos token id into the sampling_params to support
290+
# min_tokens processing.
291+
self.all_stop_token_ids.add(model_eos_token_id)
292+
285293
# Update eos_token_id for generation
286-
if (not self.ignore_eos) and (eos_ids :=
287-
generation_config.get("eos_token_id")):
294+
if (eos_ids := generation_config.get("eos_token_id")) is not None:
288295
# it can be either int or list of int
289-
if isinstance(eos_ids, int):
290-
eos_ids = [eos_ids]
291-
original_stop_token_ids = set(self.stop_token_ids)
292-
original_stop_token_ids.update(eos_ids)
293-
self.stop_token_ids = list(original_stop_token_ids)
296+
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
297+
if model_eos_token_id is not None:
298+
# We don't need to include the primary eos_token_id in
299+
# stop_token_ids since it's handled separately for stopping
300+
# purposes.
301+
eos_ids.discard(model_eos_token_id)
302+
if eos_ids:
303+
self.all_stop_token_ids.update(eos_ids)
304+
if not self.ignore_eos:
305+
eos_ids.update(self.stop_token_ids)
306+
self.stop_token_ids = list(eos_ids)
294307

295308
@cached_property
296309
def sampling_type(self) -> SamplingType:

0 commit comments

Comments
 (0)