Commit c54269d

[Frontend] Add tokenize/detokenize endpoints (#5054)
1 parent 5bfd1bb commit c54269d

5 files changed: +143 -6 lines
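
In practice the two new routes are called directly over HTTP, alongside the existing OpenAI-compatible endpoints. A minimal client sketch using the requests library, assuming a server already running on localhost:8000 and serving a model registered as "facebook/opt-125m" (address and model name are placeholders, not taken from this commit):

import requests

BASE = "http://localhost:8000"    # assumed server address
MODEL = "facebook/opt-125m"       # assumed served model name

# /tokenize: text in; token IDs, count, and max_model_len out.
tok = requests.post(BASE + "/tokenize",
                    json={
                        "model": MODEL,
                        "prompt": "This is a test prompt.",
                        "add_special_tokens": False,
                    })
tok.raise_for_status()
print(tok.json())    # {"tokens": [...], "count": ..., "max_model_len": ...}

# /detokenize: token IDs in; reconstructed prompt out.
detok = requests.post(BASE + "/detokenize",
                      json={
                          "model": MODEL,
                          "tokens": tok.json()["tokens"],
                      })
detok.raise_for_status()
print(detok.json())  # {"prompt": "This is a test prompt."}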

tests/entrypoints/test_openai_server.py
Lines changed: 49 additions & 0 deletions

@@ -9,6 +9,7 @@
 # using Ray for overall ease of process management, parallel requests,
 # and debugging.
 import ray
+import requests
 import torch
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
@@ -1366,5 +1367,53 @@ async def test_long_seed(client: openai.AsyncOpenAI):
                 or "less_than_equal" in exc_info.value.message)


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize(server, client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3]
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+
+    for add_special in [False, True]:
+        prompt = "This is a test prompt."
+        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_special_tokens": add_special,
+                                     "model": model_name,
+                                     "prompt": prompt
+                                 })
+        response.raise_for_status()
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_detokenize(server, client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3]
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+
+    prompt = "This is a test prompt."
+    tokens = tokenizer.encode(prompt, add_special_tokens=False)
+
+    response = requests.post(base_url + "detokenize",
+                             json={
+                                 "model": model_name,
+                                 "tokens": tokens
+                             })
+    response.raise_for_status()
+    assert response.json() == {"prompt": prompt}
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

vllm/entrypoints/openai/api_server.py
Lines changed: 30 additions & 1 deletion

@@ -19,10 +19,17 @@
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
                                               CompletionRequest,
-                                              EmbeddingRequest, ErrorResponse)
+                                              DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              EmbeddingRequest, ErrorResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse)
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -85,6 +92,28 @@ async def health() -> Response:
     return Response(status_code=200)


+@app.post("/tokenize")
+async def tokenize(request: TokenizeRequest):
+    generator = await openai_serving_completion.create_tokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, TokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@app.post("/detokenize")
+async def detokenize(request: DetokenizeRequest):
+    generator = await openai_serving_completion.create_detokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, DetokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
 @app.get("/v1/models")
 async def show_available_models():
     models = await openai_serving_chat.show_available_models()
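
When the requested model is not one of the served models, _check_model returns an ErrorResponse and the handlers above serialize it with its own status code instead of a TokenizeResponse or DetokenizeResponse. A small sketch of that path (server address is assumed; the exact status code and message come from ErrorResponse, not from this diff):

import requests

resp = requests.post("http://localhost:8000/tokenize",
                     json={"model": "not-a-served-model", "prompt": "hi"})
if not resp.ok:
    # Body is the model_dump() of the ErrorResponse returned by _check_model.
    print(resp.status_code, resp.json())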

vllm/entrypoints/openai/protocol.py
Lines changed: 21 additions & 0 deletions

@@ -699,3 +699,24 @@ class BatchRequestOutput(OpenAIBaseModel):
     # For requests that failed with a non-HTTP error, this will contain more
     # information on the cause of the failure.
     error: Optional[Any]
+
+
+class TokenizeRequest(OpenAIBaseModel):
+    model: str
+    prompt: str
+    add_special_tokens: bool = Field(default=True)
+
+
+class TokenizeResponse(OpenAIBaseModel):
+    tokens: List[int]
+    count: int
+    max_model_len: int
+
+
+class DetokenizeRequest(OpenAIBaseModel):
+    model: str
+    tokens: List[int]
+
+
+class DetokenizeResponse(OpenAIBaseModel):
+    prompt: str
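
For reference, the wire format implied by the four schema classes above, written out as Python dict literals; the model name, token IDs, count, and max_model_len values are illustrative only:

tokenize_request = {
    "model": "facebook/opt-125m",         # assumed served model name
    "prompt": "This is a test prompt.",
    "add_special_tokens": False,          # defaults to True when omitted
}
tokenize_response = {
    "tokens": [713, 16, 10, 1296, 14302, 4],  # illustrative token IDs
    "count": 6,
    "max_model_len": 2048,                # the engine's context length
}

detokenize_request = {
    "model": "facebook/opt-125m",
    "tokens": [713, 16, 10, 1296, 14302, 4],
}
detokenize_response = {
    "prompt": "This is a test prompt.",
}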

vllm/entrypoints/openai/serving_completion.py
Lines changed: 31 additions & 1 deletion

@@ -16,7 +16,11 @@
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              UsageInfo)
+                                              DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse, UsageInfo)
+# yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
@@ -442,3 +446,29 @@ def _create_completion_logprobs(
             tokens=out_tokens,
             top_logprobs=out_top_logprobs,
         )
+
+    async def create_tokenize(self,
+                              request: TokenizeRequest) -> TokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request,
+            prompt=request.prompt,
+            add_special_tokens=request.add_special_tokens)
+
+        return TokenizeResponse(tokens=input_ids,
+                                count=len(input_ids),
+                                max_model_len=self.max_model_len)
+
+    async def create_detokenize(
+            self, request: DetokenizeRequest) -> DetokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request, prompt_ids=request.tokens)
+
+        return DetokenizeResponse(prompt=input_text)
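
Because TokenizeResponse reports both the token count and the engine's max_model_len, a client can check whether a prompt leaves room for generation before submitting a completion request. A small sketch of that pattern (server address, model name, and the reservation heuristic are assumptions, not part of this commit):

import requests

def fits_in_context(prompt: str, reserve_for_output: int = 256) -> bool:
    # Ask the server to tokenize the prompt, then compare against its context length.
    info = requests.post("http://localhost:8000/tokenize",
                         json={"model": "facebook/opt-125m",
                               "prompt": prompt}).json()
    return info["count"] + reserve_for_output <= info["max_model_len"]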

vllm/entrypoints/openai/serving_engine.py
Lines changed: 12 additions & 4 deletions

@@ -10,9 +10,10 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest,
+                                              DetokenizeRequest,
                                               EmbeddingRequest, ErrorResponse,
                                               ModelCard, ModelList,
-                                              ModelPermission)
+                                              ModelPermission, TokenizeRequest)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob
@@ -99,8 +100,9 @@ def create_streaming_error_response(
         return json_str

     async def _check_model(
-        self, request: Union[CompletionRequest, ChatCompletionRequest,
-                             EmbeddingRequest]
+        self, request: Union[ChatCompletionRequest, CompletionRequest,
+                             DetokenizeRequest, EmbeddingRequest,
+                             TokenizeRequest]
     ) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
             return None
@@ -126,7 +128,8 @@ def _maybe_get_lora(
     def _validate_prompt_and_tokenize(
         self,
         request: Union[ChatCompletionRequest, CompletionRequest,
-                       EmbeddingRequest],
+                       DetokenizeRequest, EmbeddingRequest,
+                       TokenizeRequest],
         prompt: Optional[str] = None,
         prompt_ids: Optional[List[int]] = None,
         truncate_prompt_tokens: Optional[Annotated[int,
@@ -174,6 +177,11 @@ def _validate_prompt_and_tokenize(
                     f"generation. Please reduce the length of the input.", )
             return input_ids, input_text

+        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
+        # and does not require model context length validation
+        if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
+            return input_ids, input_text
+
         if request.max_tokens is None:
             if token_num >= self.max_model_len:
                 raise ValueError(
