Skip to content

Commit

Permalink
Add option to completion API to truncate prompt tokens (vllm-project#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tdoublep authored Apr 5, 2024
1 parent 095f5c7 commit 80fb35b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 8 deletions.
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, conint, model_validator

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
Expand Down Expand Up @@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[conint(ge=1)] = None
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
Expand Down Expand Up @@ -309,6 +310,7 @@ def logit_bias_logits_processor(
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)

@model_validator(mode="before")
Expand Down
10 changes: 8 additions & 2 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,16 @@ async def create_completion(self, request: CompletionRequest,
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
input_ids = self._validate_prompt_and_tokenize(
request, prompt=prompt)
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)

generators.append(
self.engine.generate(prompt,
Expand Down
22 changes: 18 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from http import HTTPStatus
from typing import Dict, List, Optional, Union

from pydantic import conint

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, ErrorResponse,
Expand Down Expand Up @@ -66,7 +68,8 @@ async def _post_init(self):
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")

async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
Expand Down Expand Up @@ -164,15 +167,26 @@ def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]:
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None
) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")

input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
if prompt_ids is None:
tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
"truncation": True,
"max_length": truncate_prompt_tokens,
}
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
input_ids = prompt_ids

token_num = len(input_ids)

if request.max_tokens is None:
Expand Down
13 changes: 12 additions & 1 deletion vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable, List, Optional, Union

import torch
from pydantic import conint

_SAMPLING_EPS = 1e-5

Expand Down Expand Up @@ -94,6 +95,9 @@ class SamplingParams:
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
"""

def __init__(
Expand Down Expand Up @@ -123,6 +127,7 @@ def __init__(
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
Expand Down Expand Up @@ -160,6 +165,7 @@ def __init__(
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
Expand Down Expand Up @@ -216,6 +222,10 @@ def _verify_args(self) -> None:
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")
if (self.truncate_prompt_tokens is not None
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
if self.stop and not self.detokenize:
raise ValueError(
"stop strings are only supported when detokenize is True. "
Expand Down Expand Up @@ -300,4 +310,5 @@ def __repr__(self) -> str:
f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens})")
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")

0 comments on commit 80fb35b

Please sign in to comment.