Commit e46a60a

[BugFix] Fix handling of stop strings and stop token ids (vllm-project#3672)
1 parent 1e96c33 commit e46a60a

File tree

8 files changed: +202 -37 lines
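
For orientation, here is a minimal usage sketch (not part of the commit) of the options this fix covers. The model name, prompt, and stop token id are placeholders borrowed from the new test file below; the exact generated text depends on the model.

    from vllm import LLM, SamplingParams

    # Stop strings and stop token ids both end generation early;
    # include_stop_str_in_output controls whether a matched stop
    # string is kept in the returned text.
    llm = LLM(model="meta-llama/llama-2-7b-hf")  # placeholder model

    params = SamplingParams(
        temperature=0.0,
        max_tokens=200,
        stop=["."],              # finish at the first period...
        stop_token_ids=[13013],  # ...or at this token id, whichever comes first
        include_stop_str_in_output=False,
    )

    (request_output, ) = llm.generate(["A story about vLLM:\n"], params)
    completion = request_output.outputs[0]
    print(completion.text, completion.stop_reason)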

tests/conftest.py

Lines changed: 1 addition & 1 deletion

@@ -401,7 +401,7 @@ def __del__(self):
         cleanup()
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def vllm_runner():
     return VllmRunner
 
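
As an aside, a tiny sketch (not from this commit) of what the scope change means: a session-scoped fixture body runs once per pytest session and its return value is shared by all tests, rather than being re-created for every test.

    import pytest

    @pytest.fixture(scope="session")
    def expensive_resource():
        # Hypothetical fixture: built once, then reused by every test
        # in the session.
        return object()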

tests/samplers/test_stop_reason.py renamed to tests/engine/test_stop_reason.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 2. One of the provided stop tokens
 3. The EOS token
 
-Run `pytest tests/samplers/test_stop_reason.py`.
+Run `pytest tests/engine/test_stop_reason.py`.
 """
 
 import pytest

tests/engine/test_stop_strings.py

Lines changed: 111 additions & 0 deletions

@@ -0,0 +1,111 @@
+from typing import Any, List, Optional
+
+import pytest
+
+from vllm import CompletionOutput, LLMEngine, SamplingParams
+
+MODEL = "meta-llama/llama-2-7b-hf"
+MAX_TOKENS = 200
+
+
+@pytest.fixture(scope="session")
+def vllm_model(vllm_runner):
+    return vllm_runner(MODEL)
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_basic(vllm_model):
+    _test_stopping(vllm_model.model.llm_engine,
+                   stop=["."],
+                   include_in_output=False,
+                   expected_output="VLLM is a 100% volunteer organization",
+                   expected_reason=".")
+
+    _test_stopping(vllm_model.model.llm_engine,
+                   stop=["."],
+                   include_in_output=True,
+                   expected_output="VLLM is a 100% volunteer organization.",
+                   expected_reason=".")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_multi_tokens(vllm_model):
+    _test_stopping(
+        vllm_model.model.llm_engine,
+        stop=["group of peo", "short"],
+        include_in_output=False,
+        expected_output="VLLM is a 100% volunteer organization. We are a ",
+        expected_reason="group of peo")
+
+    _test_stopping(
+        vllm_model.model.llm_engine,
+        stop=["group of peo", "short"],
+        include_in_output=True,
+        expected_output=
+        "VLLM is a 100% volunteer organization. We are a group of peo",
+        expected_reason="group of peo")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_partial_token(vllm_model):
+    _test_stopping(vllm_model.model.llm_engine,
+                   stop=["gani"],
+                   include_in_output=False,
+                   expected_output="VLLM is a 100% volunteer or",
+                   expected_reason="gani")
+
+    _test_stopping(vllm_model.model.llm_engine,
+                   stop=["gani"],
+                   include_in_output=True,
+                   expected_output="VLLM is a 100% volunteer organi",
+                   expected_reason="gani")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_token_id(vllm_model):
+    # token id 13013 => " organization"
+
+    _test_stopping(vllm_model.model.llm_engine,
+                   stop_token_ids=[13013],
+                   include_in_output=False,
+                   expected_output="VLLM is a 100% volunteer",
+                   expected_reason=13013)
+
+    _test_stopping(vllm_model.model.llm_engine,
+                   stop_token_ids=[13013],
+                   include_in_output=True,
+                   expected_output="VLLM is a 100% volunteer organization",
+                   expected_reason=13013)
+
+
+def _test_stopping(llm_engine: LLMEngine,
+                   expected_output: str,
+                   expected_reason: Any,
+                   stop: Optional[List[str]] = None,
+                   stop_token_ids: Optional[List[int]] = None,
+                   include_in_output: bool = False) -> None:
+    llm_engine.add_request(
+        "id", "A story about vLLM:\n",
+        SamplingParams(
+            temperature=0.0,
+            max_tokens=MAX_TOKENS,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
+            include_stop_str_in_output=include_in_output,
+        ), None)
+
+    output: Optional[CompletionOutput] = None
+    output_text = ""
+    stop_reason = None
+    while llm_engine.has_unfinished_requests():
+        (request_output, ) = llm_engine.step()
+        (output, ) = request_output.outputs
+
+        # Ensure we don't backtrack
+        assert output.text.startswith(output_text)
+        output_text = output.text
+        stop_reason = output.stop_reason
+
+    assert output is not None
+    assert output_text == expected_output
+    assert stop_reason == expected_reason
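
As the expected_reason values above suggest, CompletionOutput.stop_reason carries the matched stop string (a str) when a stop string ended generation and the matched token id (an int) when a stop token did. A small illustrative helper, not part of vLLM, showing how a caller might branch on it:

    from typing import Optional, Union

    def describe_stop_reason(stop_reason: Optional[Union[int, str]]) -> str:
        # Interpret CompletionOutput.stop_reason from a finished request.
        if isinstance(stop_reason, str):
            return f"stopped on stop string {stop_reason!r}"
        if isinstance(stop_reason, int):
            return f"stopped on stop token id {stop_reason}"
        return "finished without matching a stop string or stop token"

    assert describe_stop_reason("gani") == "stopped on stop string 'gani'"
    assert describe_stop_reason(13013) == "stopped on stop token id 13013"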

vllm/engine/llm_engine.py

Lines changed: 65 additions & 33 deletions

@@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
 
         for seq, _ in child_seqs:
             if seq_group.sampling_params.detokenize:
-                self.detokenizer.decode_sequence_inplace(
+                new_char_count = self.detokenizer.decode_sequence_inplace(
                     seq, seq_group.sampling_params)
-                self._check_stop(seq, seq_group.sampling_params)
+            else:
+                new_char_count = 0
+            self._check_stop(seq, new_char_count, seq_group.sampling_params)
 
         # Non-beam search case
         if not seq_group.sampling_params.use_beam_search:
@@ -798,56 +800,86 @@ def _get_stats(self,
             time_e2e_requests=time_e2e_requests,
         )
 
-    def _check_stop(self, seq: Sequence,
+    def _check_stop(self, seq: Sequence, new_char_count: int,
                     sampling_params: SamplingParams) -> None:
-        """Stop the finished sequences."""
-        # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.scheduler_config.max_model_len:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
+        """Stop the finished sequences.
 
-        # Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
+        new_char_count is the number of chars added to the
+        sequence's output text for the newly generated token
+        """
 
         # Check if the minimum number of tokens has been generated yet;
         # skip the stop string/token checks if not
         if seq.get_output_len() < sampling_params.min_tokens:
            return
 
-        if sampling_params.detokenize:
-            for stop_str in sampling_params.stop:
-                if seq.output_text.endswith(stop_str):
-                    self._finalize_sequence(seq, sampling_params, stop_str)
-                    seq.status = SequenceStatus.FINISHED_STOPPED
-                    seq.stop_reason = stop_str
-                    return
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
+
+        # Check if a stop token was encountered.
+        # This assumes a single token produced per step.
         last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
-            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                last_token_id)
-            self._finalize_sequence(seq, sampling_params, stop_str)
+            if new_char_count and (
+                    not sampling_params.include_stop_str_in_output):
+                # Remove last token
+                seq.output_text = seq.output_text[:-new_char_count]
             seq.status = SequenceStatus.FINISHED_STOPPED
             seq.stop_reason = last_token_id
             return
 
-        # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == seq.eos_token_id):
+        # Check if any stop strings are matched.
+        stop_str = self._check_stop_strings(seq, new_char_count,
+                                            sampling_params)
+        if stop_str is not None:
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = stop_str
             return
 
-    def _finalize_sequence(self, seq: Sequence,
-                           sampling_params: SamplingParams,
-                           stop_string: str) -> None:
-        if sampling_params.include_stop_str_in_output:
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
             return
 
-        if stop_string and seq.output_text.endswith(stop_string):
-            # Truncate the output text so that the stop string is
-            # not included in the output.
-            seq.output_text = seq.output_text[:-len(stop_string)]
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+    @staticmethod
+    def _check_stop_strings(seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> Optional[str]:
+        """Check if any stop strings are matched and truncate sequence
+        output text accordingly.
+
+        Returns the stop string if matched or else None.
+        """
+        if not new_char_count:
+            return None
+
+        for stop_str in sampling_params.stop:
+            stop_string_len = len(stop_str)
+            # Avoid searching already-searched text.
+            stop_index = seq.output_text.find(
+                stop_str, -new_char_count - stop_string_len)
+            if stop_index == -1:
+                continue
+
+            if sampling_params.include_stop_str_in_output:
+                # Truncate to end of stop string.
+                stop_index += stop_string_len
+                if stop_index >= len(seq.output_text):
+                    # No truncation required.
+                    return stop_str
+
+            # Truncate the output text to either the beginning
+            # or end of the stop string.
+            seq.output_text = seq.output_text[:stop_index]
+            return stop_str
+        return None
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)
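
A standalone sketch (not part of the commit) of the windowed search that the new _check_stop_strings performs: only the last new_char_count characters are new, so a freshly completed match must end inside that region and can start at most len(stop_str) - 1 characters before it. The helper name and strings below are illustrative only.

    def find_new_stop_match(output_text: str, stop_str: str,
                            new_char_count: int) -> int:
        # Search only the tail that could newly contain the stop string;
        # str.find clamps an overly negative start index to 0.
        return output_text.find(stop_str, -new_char_count - len(stop_str))

    # The newest token " organi" (7 chars) completes the stop string "gani",
    # which began in previously generated text.
    text = "VLLM is a 100% volunteer" + " organi"
    idx = find_new_stop_match(text, "gani", new_char_count=7)
    assert text[:idx] == "VLLM is a 100% volunteer or"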

vllm/outputs.py

Lines changed: 3 additions & 1 deletion

@@ -112,8 +112,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # always has the logprobs of the sampled tokens even if the
         # logprobs are not requested.
         include_logprobs = seq_group.sampling_params.logprobs is not None
+        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(seqs.index(seq), seq.output_text,
+            CompletionOutput(seqs.index(seq),
+                             seq.get_output_text_to_return(text_buffer_length),
                              seq.get_output_token_ids(),
                              seq.get_cumulative_logprob(),
                              seq.output_logprobs if include_logprobs else None,

vllm/sampling_params.py

Lines changed: 9 additions & 0 deletions

@@ -166,6 +166,13 @@ def __init__(
         self.logits_processors = logits_processors
         self.include_stop_str_in_output = include_stop_str_in_output
         self.truncate_prompt_tokens = truncate_prompt_tokens
+        # Number of characters to hold back for stop string evaluation
+        # until sequence is finished.
+        if self.stop and not include_stop_str_in_output:
+            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+        else:
+            self.output_text_buffer_length = 0
+
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -226,6 +233,8 @@ def _verify_args(self) -> None:
                 and self.truncate_prompt_tokens < 1):
             raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                              f"got {self.truncate_prompt_tokens}")
+        if any(not stop_str for stop_str in self.stop):
+            raise ValueError("stop cannot contain an empty string.")
         if self.stop and not self.detokenize:
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "

vllm/sequence.py

Lines changed: 6 additions & 0 deletions

@@ -235,6 +235,12 @@ def __init__(
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
+    def get_output_text_to_return(self, buffer_length: int):
+        # We return the full output text if the sequence is finished.
+        truncate = buffer_length and not self.is_finished()
+        return self.output_text[:-buffer_length] if truncate else (
+            self.output_text)
+
     def hash_of_block(self, logical_idx: int) -> int:
         # TODO This can produce incorrect hash when block size > prompt size
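
A quick illustration (hypothetical values, not vLLM code) of the slicing this helper performs: while the sequence is still running, the last buffer_length characters are withheld from callers; once it finishes, the full text is returned.

    def get_output_text_to_return(output_text: str, buffer_length: int,
                                  finished: bool) -> str:
        # Stand-in mirroring Sequence.get_output_text_to_return.
        truncate = buffer_length and not finished
        return output_text[:-buffer_length] if truncate else output_text

    text = "We are a group of peo"
    assert get_output_text_to_return(text, 11, finished=False) == "We are a g"
    assert get_output_text_to_return(text, 11, finished=True) == text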

vllm/transformers_utils/detokenizer.py

Lines changed: 6 additions & 1 deletion

@@ -87,12 +87,15 @@ def decode_prompt_logprobs_inplace(
             prev_tokens.extend(next_iter_tokens)
 
     def decode_sequence_inplace(self, seq: Sequence,
-                                prms: SamplingParams) -> None:
+                                prms: SamplingParams) -> int:
         """Decodes the new token for a sequence. In-place operation.
 
         Args:
             seq: The sequence to decode.
             prms: The sampling parameters used to generate the sequence.
+
+        Returns:
+            The number of characters added to the output text.
         """
         all_input_ids = seq.get_token_ids()
         token_id_generated_this_iteration = all_input_ids[-1]
@@ -151,6 +154,8 @@ def decode_sequence_inplace(self, seq: Sequence,
         seq.read_offset = read_offset
         seq.output_text += new_decoded_token_text
 
+        return len(new_decoded_token_text)
+
 
 def _convert_tokens_to_string_with_added_encoders(
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
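
For intuition, a toy sketch (not the vLLM API) of why the character count is returned: a single generated token can append several characters at once, and the stop checks in _check_stop only need to look at text that includes those new characters.

    def append_decoded_token(output_text: str, new_decoded_token_text: str):
        # Mirrors the in-place update above: the caller learns how many
        # characters the newest token contributed.
        output_text += new_decoded_token_text
        return output_text, len(new_decoded_token_text)

    text, new_char_count = append_decoded_token("VLLM is a 100% volunteer",
                                                " organization")
    assert new_char_count == 13  # " organization" adds 13 characters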
