
Add option to completion API to truncate prompt tokens #3144

Merged (13 commits) on Apr 5, 2024
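In short, the PR adds a truncate_prompt_tokens option to the completion API: when set to an integer k, only the last k prompt tokens are kept (left truncation). A minimal sketch of a client request against a locally running vLLM OpenAI-compatible server (the URL and model name are placeholders):

# Sketch only: assumes a vLLM server started with the OpenAI-compatible entrypoint.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",                 # placeholder
        "prompt": "A very long prompt ...",
        "max_tokens": 64,
        # New in this PR: keep only the last 128 prompt tokens.
        "truncate_prompt_tokens": 128,
    },
)
print(resp.json()["choices"][0]["text"])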
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -4,7 +4,7 @@
from typing import Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, conint, model_validator

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
@@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[conint(ge=1)] = None
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
@@ -309,6 +310,7 @@ def logit_bias_logits_processor(
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)

@model_validator(mode="before")
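The conint(ge=1) annotation lets pydantic reject non-positive values at request-validation time, before the value ever reaches SamplingParams. A small illustrative check of that constraint (the model class below is not part of the diff):

# Illustrative only: exercises the same field type added to CompletionRequest.
from typing import Optional
from pydantic import BaseModel, ValidationError, conint

class TruncationField(BaseModel):
    truncate_prompt_tokens: Optional[conint(ge=1)] = None

TruncationField(truncate_prompt_tokens=128)   # accepted
TruncationField()                             # accepted: defaults to None (no truncation)
try:
    TruncationField(truncate_prompt_tokens=0)  # rejected: must be >= 1
except ValidationError as err:
    print(err)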
10 changes: 8 additions & 2 deletions vllm/entrypoints/openai/serving_completion.py
@@ -137,10 +137,16 @@ async def create_completion(self, request: CompletionRequest,
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
input_ids = self._validate_prompt_and_tokenize(
request, prompt=prompt)
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)

generators.append(
self.engine.generate(prompt,
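This hunk forwards the option for both pre-tokenized prompts and plain-text prompts. A sketch of the pre-tokenized case through the official OpenAI Python client (openai>=1.x), where vLLM-specific fields are passed via extra_body; the base_url, api_key, and model name are placeholders:

# Sketch only: assumes a vLLM OpenAI-compatible server at the given base_url.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="my-model",                          # placeholder
    prompt=[1, 2, 3, 4, 5],                    # already-tokenized prompt
    max_tokens=64,
    # Forwarded into CompletionRequest and then into SamplingParams.
    extra_body={"truncate_prompt_tokens": 128},
)
print(completion.choices[0].text)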
22 changes: 18 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py
@@ -4,6 +4,8 @@
from http import HTTPStatus
from typing import Dict, List, Optional, Union

from pydantic import conint

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, ErrorResponse,
@@ -66,7 +68,8 @@ async def _post_init(self):
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")

if len(self.tokenizer) != engine_model_config.get_vocab_size():
logger.warning(
@@ -172,15 +175,26 @@ def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]:
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None
) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")

input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
if prompt_ids is None:
tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
"truncation": True,
"max_length": truncate_prompt_tokens,
}
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
input_ids = prompt_ids

token_num = len(input_ids)

if request.max_tokens is None:
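Both branches implement the same left-truncation semantics: keep only the last k tokens of the prompt. A standalone sketch of the two paths using a generic Hugging Face tokenizer (the tokenizer name is just an example):

# Illustrative only: mirrors the logic in _validate_prompt_and_tokenize.
from transformers import AutoTokenizer

k = 128  # truncate_prompt_tokens
tokenizer = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")

# Text prompt: the tokenizer truncates from the left down to at most k tokens.
text_ids = tokenizer("a very long prompt ...",
                     truncation=True, max_length=k).input_ids

# Pre-tokenized prompt: simply keep the last k token ids.
prompt_ids = list(range(1000))
truncated_ids = prompt_ids[-k:]

assert len(text_ids) <= k and len(truncated_ids) == k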
13 changes: 12 additions & 1 deletion vllm/sampling_params.py
@@ -5,6 +5,7 @@
from typing import Callable, List, Optional, Union

import torch
from pydantic import conint

_SAMPLING_EPS = 1e-5

@@ -93,6 +94,9 @@ class SamplingParams:
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
"""

def __init__(
@@ -121,6 +125,7 @@ def __init__(
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
@@ -154,6 +159,7 @@ def __init__(
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
@@ -210,6 +216,10 @@ def _verify_args(self) -> None:
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")
if (self.truncate_prompt_tokens is not None
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")

def _verify_beam_search(self) -> None:
if self.best_of == 1:
@@ -290,4 +300,5 @@ def __repr__(self) -> str:
f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens})")
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")