Skip to content

Commit fa5a295

Browse files
njhillAlvant
authored andcommitted
[BugFix] Fix min_tokens behaviour for multiple eos tokens (vllm-project#5849)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 7679f1a commit fa5a295

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

vllm/engine/llm_engine.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -606,12 +606,9 @@ def _create_sequence_group_with_sampling(
606606
# Defensive copy of SamplingParams, which are used by the sampler,
607607
# this doesn't deep-copy LogitsProcessor objects
608608
sampling_params = sampling_params.clone()
609-
# Add the eos token id into the sampling_params to support min_tokens
610-
# processing
611-
if seq.eos_token_id is not None:
612-
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
609+
613610
sampling_params.update_from_generation_config(
614-
self.generation_config_fields)
611+
self.generation_config_fields, seq.eos_token_id)
615612

616613
# Create the sequence group.
617614
seq_group = SequenceGroup(

vllm/sampling_params.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -364,17 +364,30 @@ def _verify_greedy_sampling(self) -> None:
364364
f"Got {self.best_of}.")
365365

366366
def update_from_generation_config(
367-
self, generation_config: Dict[str, Any]) -> None:
367+
self,
368+
generation_config: Dict[str, Any],
369+
model_eos_token_id: Optional[int] = None) -> None:
368370
"""Update if there are non-default values from generation_config"""
371+
372+
if model_eos_token_id is not None:
373+
# Add the eos token id into the sampling_params to support
374+
# min_tokens processing.
375+
self.all_stop_token_ids.add(model_eos_token_id)
376+
369377
# Update eos_token_id for generation
370-
if (not self.ignore_eos) and (eos_ids :=
371-
generation_config.get("eos_token_id")):
378+
if (eos_ids := generation_config.get("eos_token_id")) is not None:
372379
# it can be either int or list of int
373-
if isinstance(eos_ids, int):
374-
eos_ids = [eos_ids]
375-
original_stop_token_ids = set(self.stop_token_ids)
376-
original_stop_token_ids.update(eos_ids)
377-
self.stop_token_ids = list(original_stop_token_ids)
380+
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
381+
if model_eos_token_id is not None:
382+
# We don't need to include the primary eos_token_id in
383+
# stop_token_ids since it's handled separately for stopping
384+
# purposes.
385+
eos_ids.discard(model_eos_token_id)
386+
if eos_ids:
387+
self.all_stop_token_ids.update(eos_ids)
388+
if not self.ignore_eos:
389+
eos_ids.update(self.stop_token_ids)
390+
self.stop_token_ids = list(eos_ids)
378391

379392
@cached_property
380393
def sampling_type(self) -> SamplingType:

0 commit comments

Comments
 (0)