
Commit cb621ef

njhill authored and kdamaszk committed
[BugFix] Work-around incremental detokenization edge case error (vllm-project#19449)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 2b8dde9 commit cb621ef

File tree

2 files changed: +113 -6 lines
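
Background, not part of the commit: incremental detokenization has to carry state across tokens because a single multi-byte UTF-8 character can straddle token boundaries, so decoding each chunk independently yields replacement characters instead of the real text. A minimal Python illustration of that effect:

```python
# 'é' is two bytes in UTF-8; an incomplete prefix of those bytes cannot be
# decoded to the character, which is why streaming decoders must keep state.
data = "é".encode("utf-8")                         # b'\xc3\xa9'
print(data[:1].decode("utf-8", errors="replace"))  # '�' (incomplete prefix)
print(data.decode("utf-8"))                        # 'é' (complete sequence)
```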
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers import AutoTokenizer

from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import IncrementalDetokenizer

# ruff: noqa: E501


def test_fast_inc_detok_invalid_utf8_err_case():
    """
    Test edge case where tokenizer can produce non-monotonic,
    invalid UTF-8 output, which breaks the internal state of
    tokenizers' DecodeStream.
    See https://github.com/vllm-project/vllm/issues/17448.

    Thanks to reproducer from @fpaupier:
    https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3.
    """
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

    # Create a test request
    prompt_token_ids = [107, 4606, 236787, 107]
    params = SamplingParams(skip_special_tokens=True)
    request = EngineCoreRequest(
        "test",
        prompt_token_ids,
        None,
        None,
        None,
        params,
        None,
        0.0,
        None,
        cache_salt=None,
        data_parallel_rank=None,
    )

    detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)

    assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \
        "Should use FastIncrementalDetokenizer by default"

    # Process tokens incrementally
    test_tokens = [
        236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
        147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
        827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
        569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
        35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
        236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
        236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
        236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
        4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
        19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
        432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
        2555, 513, 236789, 602, 31118, 569
    ]

    output = ""
    for i, token_id in enumerate(test_tokens):
        detokenizer.update([token_id], False)

        finished = i == len(test_tokens) - 1
        output += detokenizer.get_next_output_text(finished, delta=True)

    # fmt: off
    assert output == r'''[
  {
    "source": "Résultats",
    "source_type": "CONCEPT",
    "source_description": "Résultats de l'analyse de l'impact des opérations israéliennes sur la frontière libanaise",
    "target": "Israël",
    "target_type": "ORGANIZATION",
    "target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
    "relationship": "Obtention d'un niveau de'''
```

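To see the underlying failure outside of vLLM, the tokenizers DecodeStream can be driven directly. A minimal sketch, not part of the commit, assuming tokenizers >= 0.21.1 and access to google/gemma-3-1b-it; fed the full token sequence from the test above, a raw stream like this is what raises the "Invalid prefix encountered" error the commit works around:

```python
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tok = Tokenizer.from_pretrained("google/gemma-3-1b-it")
stream = DecodeStream(skip_special_tokens=True)

output = ""
for token_id in [236840, 107, 138]:  # illustrative prefix of the test tokens
    piece = stream.step(tok, token_id)  # returns None while mid-character
    if piece is not None:
        output += piece
print(output)
```
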
vllm/v1/engine/detokenizer.py

Lines changed: 33 additions & 6 deletions
```diff
@@ -16,6 +16,14 @@
 
 logger = init_logger(__name__)
 
+# Only tokenizers >= 0.21.1 supports DecodeStream used for
+# FastIncrementalDetokenizer.
+USE_FAST_DETOKENIZER = version.parse(
+    tokenizers.__version__) >= version.parse("0.21.1")
+
+# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
+INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
+
 
 class IncrementalDetokenizer:
 
@@ -45,10 +53,9 @@ def from_new_request(
             # No tokenizer => skipping detokenization.
             return IncrementalDetokenizer()
 
-        if (isinstance(tokenizer, PreTrainedTokenizerFast) and version.parse(
-                tokenizers.__version__) >= version.parse("0.21.1")):
+        if USE_FAST_DETOKENIZER and isinstance(tokenizer,
+                                               PreTrainedTokenizerFast):
             # Fast tokenizer => use tokenizers library DecodeStream.
-            # And only tokenizers >= 0.21.1 supports Fast Detokenizer.
             return FastIncrementalDetokenizer(tokenizer, request)
 
         # Fall back to slow python-based incremental detokenization.
@@ -156,8 +163,11 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast,
         super().__init__(request)
 
         sampling_params = request.sampling_params
+
+        self.request_id = request.request_id
+        self.skip_special_tokens = sampling_params.skip_special_tokens
         self.stream = DecodeStream(
-            skip_special_tokens=sampling_params.skip_special_tokens)
+            skip_special_tokens=self.skip_special_tokens)
 
         self.tokenizer: Tokenizer = tokenizer._tokenizer
 
@@ -173,7 +183,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast,
 
         # Prime the stream.
         for tid in prompt_suffix:
-            self.stream.step(self.tokenizer, tid)
+            self._protected_step(tid)
 
         self.spaces_between_special_tokens = (
             sampling_params.skip_special_tokens
@@ -198,7 +208,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast,
             self.spaces_between_special_tokens = True
 
     def decode_next(self, next_token_id: int) -> str:
-        token = self.stream.step(self.tokenizer, next_token_id)
+        token = self._protected_step(next_token_id)
 
         if not self.spaces_between_special_tokens:
             special_token = self.added_token_ids.get(next_token_id)
@@ -210,6 +220,23 @@ def decode_next(self, next_token_id: int) -> str:
 
         return token or ""
 
+    def _protected_step(self, next_token_id: int) -> Optional[str]:
+        try:
+            token = self.stream.step(self.tokenizer, next_token_id)
+        except Exception as e:
+            if str(e) != INVALID_PREFIX_ERR_MSG:
+                raise e
+            # Recover from edge case where tokenizer can produce
+            # non-monotonic, invalid UTF-8 output, which breaks the
+            # internal state of tokenizers' DecodeStream.
+            # See https://github.com/vllm-project/vllm/issues/17448.
+            logger.warning(
+                "Encountered invalid prefix detokenization error"
+                " for request %s, resetting decode stream.", self.request_id)
+            self.stream = DecodeStream(self.skip_special_tokens)
+            token = self.stream.step(self.tokenizer, next_token_id)
+        return token
+
 
 class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
```
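For reference, the same reset-and-retry pattern in isolation. A self-contained sketch under stated assumptions: tokenizers >= 0.21.1, and the ProtectedStream class and its names are illustrative, not vLLM API:

```python
from typing import Optional

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

# Error string pinned in the commit, from the tokenizers Rust source.
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"


class ProtectedStream:
    """Illustrative wrapper: a DecodeStream that resets itself when it
    hits the known invalid-prefix failure instead of staying broken."""

    def __init__(self, tokenizer: Tokenizer,
                 skip_special_tokens: bool = True) -> None:
        self.tokenizer = tokenizer
        self.skip_special_tokens = skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=skip_special_tokens)

    def step(self, token_id: int) -> Optional[str]:
        try:
            return self.stream.step(self.tokenizer, token_id)
        except Exception as e:
            if str(e) != INVALID_PREFIX_ERR_MSG:
                raise
            # Same trade-off as the commit: any buffered prefix is lost,
            # so a few characters may be dropped, but decoding continues
            # instead of failing the whole request.
            self.stream = DecodeStream(
                skip_special_tokens=self.skip_special_tokens)
            return self.stream.step(self.tokenizer, token_id)
```

The narrow except mirrors the commit's design choice: only the one known error string triggers a reset, so the request keeps decoding after the work-around, while any other exception still propagates rather than being silently swallowed.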