[V1] Do not detokenize if sampling param detokenize is False #14224

Merged · 6 commits · Mar 6, 2025
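
This change makes the V1 engine honor SamplingParams(detokenize=False): token IDs and logprobs are still returned, but no output text is produced and logprob entries carry no decoded tokens. A minimal sketch of the resulting behavior, mirroring the new end-to-end test below (the model name is only an illustrative placeholder):

    from vllm import LLM, SamplingParams

    # Any small model works here; opt-125m is just a placeholder.
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)

    out = llm.generate("Hello, my name is",
                       SamplingParams(detokenize=False, max_tokens=16))

    # Token IDs are still produced, but detokenization is skipped entirely.
    assert len(out[0].outputs[0].token_ids) > 0
    assert out[0].outputs[0].text == ""
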
29 changes: 28 additions & 1 deletion tests/v1/sample/test_sampling_params_e2e.py
@@ -14,7 +14,10 @@

@pytest.fixture(scope="module")
def model() -> LLM:
return LLM(MODEL, enforce_eager=True)
# Disable prefix caching so that we can test prompt logprobs.
# TODO remove this after https://github.com/vllm-project/vllm/pull/13949
# is merged
return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)


def test_n_gt_1(model):
@@ -79,9 +82,33 @@ def test_stop_token_ids(model):

stop_token_ids = [stop_token_id_0, stop_token_id_1]
params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
output = model.generate(PROMPT, params)
assert output[0].outputs[0].token_ids[-1] == stop_token_id_0


def test_detokenize_false(model):
"""Check that detokenize=False option works."""

output = model.generate(PROMPT, SamplingParams(detokenize=False))
assert len(output[0].outputs[0].token_ids) > 0
assert len(output[0].outputs[0].text) == 0

output = model.generate(
PROMPT, SamplingParams(detokenize=False, logprobs=3,
prompt_logprobs=3))
assert len(output[0].outputs[0].token_ids) > 0
assert len(output[0].outputs[0].text) == 0

prompt_logprobs = output[0].prompt_logprobs
sampled_logprobs = output[0].outputs[0].logprobs
assert len(prompt_logprobs) > 1
assert len(sampled_logprobs) > 1
for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
for logprobs in all_logprobs:
assert 3 <= len(logprobs) <= 4
assert all(lp.decoded_token is None for lp in logprobs.values())


def test_bad_words(model):
"""Check that we respect bad words."""

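If a caller sets detokenize=False but later needs the text, the returned token IDs can be decoded outside the engine. A hedged sketch, reusing the llm and out objects from the earlier example and assuming the LLM.get_tokenizer() helper to retrieve the underlying Hugging Face tokenizer:

    # Decode later, outside the engine, only if/when text is actually needed.
    tokenizer = llm.get_tokenizer()
    token_ids = list(out[0].outputs[0].token_ids)
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    print(text)
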
40 changes: 24 additions & 16 deletions vllm/v1/engine/detokenizer.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from vllm.engine.output_processor.stop_checker import StopChecker
@@ -16,41 +16,46 @@
class IncrementalDetokenizer:

# Generation data
output_text: str
tokens: list[str]
token_ids: list[int]
prompt_len: int
output_text: str = ""
tokens: list[str] = field(default_factory=list)
prompt_len: int = 0

# Stop strings
stop: list[str]
include_stop_str_in_output: bool
stop: list[str] = field(default_factory=list)
include_stop_str_in_output: bool = False

# Metadata for incremental detokenization
prefix_offset: int
read_offset: int
prefix_offset: int = 0
read_offset: int = 0

# Parameters for detokenization
skip_special_tokens: bool
spaces_between_special_tokens: bool
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True

# Tokenizer for this request
tokenizer: AnyTokenizer
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer] = None

# Accounting for stop string buffering
stop_buffer_length: int
stop_buffer_length: int = 0
_last_output_text_offset: int = 0

@property
def output_token_ids(self) -> list[int]:
return self.token_ids[self.prompt_len:]
return self.token_ids if not self.prompt_len else (
self.token_ids[self.prompt_len:])

@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":

if tokenizer is None:
return cls(token_ids=[])

tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
@@ -66,7 +71,6 @@ def from_new_request(
stop_buffer_length = 0

return cls(
output_text="",
tokens=tokens,
# Detokenizer mutates this list, so need a unique copy.
# NOTE(Nick): could we take ownership of it though?
@@ -93,6 +97,10 @@ def update(self, new_token_ids: list[int]) -> Optional[str]:
Return matched stop string or None.
"""

if self.tokenizer is None:
self.token_ids.extend(new_token_ids)
return None

# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
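The field defaults added above exist so that cls(token_ids=[]) can build a no-op detokenizer when the tokenizer is absent, and update() then only accumulates token IDs. A stripped-down, self-contained illustration of that pattern (a simplified stand-in, not a copy of the real IncrementalDetokenizer):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class NoOpCapableDetokenizer:
        token_ids: list[int]
        output_text: str = ""
        prompt_len: int = 0
        tokenizer: Optional[object] = None  # None disables detokenization

        @property
        def output_token_ids(self) -> list[int]:
            # prompt_len stays 0 when detokenization is disabled, so no slicing.
            return (self.token_ids if not self.prompt_len else
                    self.token_ids[self.prompt_len:])

        def update(self, new_token_ids: list[int]) -> Optional[str]:
            if self.tokenizer is None:
                # Fast path: record token IDs only; no text, no stop-string checks.
                self.token_ids.extend(new_token_ids)
                return None
            raise NotImplementedError("incremental detokenization elided here")


    detok = NoOpCapableDetokenizer(token_ids=[])
    assert detok.update([1, 2, 3]) is None
    assert detok.output_token_ids == [1, 2, 3]
    assert detok.output_text == ""
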
25 changes: 15 additions & 10 deletions vllm/v1/engine/logprobs.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

@@ -13,12 +14,15 @@

logger = init_logger(__name__)

NONES = itertools.repeat(None)


@dataclass
class LogprobsProcessor:

# Tokenizer for this request
tokenizer: AnyTokenizer
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer]

# Logprobs for this request
logprobs: Optional[SampleLogprobs]
@@ -30,7 +34,7 @@ class LogprobsProcessor:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "LogprobsProcessor":
num_logprobs = request.sampling_params.logprobs
@@ -66,8 +70,8 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
token_ids_lst):

# Detokenize (non-incrementally).
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer, token_ids)
decoded_tokens = NONES if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer, token_ids))

# Sampler puts the sampled logprob in first.
sampled_token_logprob = logprobs[0]
@@ -103,9 +107,9 @@ def _update_prompt_logprobs(

# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer,
token_ids.flatten().tolist())
decoded_tokens = None if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer,
token_ids.flatten().tolist()))

# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
@@ -121,7 +125,8 @@
# Handle flattening.
offset = pos * num_logprobs
offset_end = offset + num_logprobs
decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
decoded_tokens_for_pos = NONES \
if decoded_tokens is None else decoded_tokens[offset:offset_end]

# Update with the Logprob dictionary for this pos.
self.prompt_logprobs.append(
@@ -153,7 +158,7 @@ def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
def _make_logprob_dict(
logprobs: list[float],
logprob_token_ids: list[int],
decoded_tokens: list[str],
decoded_tokens: Iterable[Optional[str]],
rank: int,
num_logprobs: int,
) -> dict[int, Logprob]:
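The NONES = itertools.repeat(None) constant introduced above lets the per-position loops zip against an endless stream of Nones when no tokenizer is present, instead of branching on every token. A self-contained sketch of that pattern (plain tuples stand in for the real Logprob objects):

    import itertools
    from collections.abc import Iterable
    from typing import Optional

    NONES = itertools.repeat(None)


    def make_logprob_dict(logprobs: list[float],
                          token_ids: list[int],
                          decoded_tokens: Iterable[Optional[str]]) -> dict:
        # decoded_tokens is either a real list of strings or the infinite NONES
        # stream; zip stops at the shortest input, so both cases work unchanged.
        return {
            tid: (lp, tok)
            for tid, lp, tok in zip(token_ids, logprobs, decoded_tokens)
        }


    # With a tokenizer: decoded strings line up with the token IDs.
    print(make_logprob_dict([-0.1, -2.3], [42, 7], ["Hello", " world"]))
    # Without a tokenizer (detokenize=False): every decoded token is None.
    print(make_logprob_dict([-0.1, -2.3], [42, 7], NONES))
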
2 changes: 2 additions & 0 deletions vllm/v1/engine/output_processor.py
@@ -68,6 +68,8 @@ def from_new_request(
queue: Optional[asyncio.Queue[RequestOutput]],
log_stats: bool,
) -> "RequestState":
if not request.sampling_params.detokenize:
tokenizer = None
return cls(
request_id=request.request_id,
parent_req=parent_req,
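
The two lines added above are where the option is wired in: when sampling_params.detokenize is false, the per-request state is built with tokenizer=None, which both from_new_request constructors now accept. A condensed sketch of that gating (build_per_request_state is a hypothetical helper; the real RequestState fields and EngineCoreRequest plumbing are elided):

    from vllm.v1.engine.detokenizer import IncrementalDetokenizer
    from vllm.v1.engine.logprobs import LogprobsProcessor


    def build_per_request_state(tokenizer, request):
        # Detokenization disabled: drop the tokenizer so downstream components
        # skip all decoding work for this request.
        if not request.sampling_params.detokenize:
            tokenizer = None
        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
        logprobs_processor = LogprobsProcessor.from_new_request(tokenizer, request)
        return detokenizer, logprobs_processor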