
Commit 32985be

[Frontend] Allow return_tokens_as_token_ids to be passed as a request param (#14066)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
1 parent dae9ec4 commit 32985be

File tree (4 files changed: +64 −25 lines)

    tests/entrypoints/openai/test_return_tokens_as_ids.py
    vllm/entrypoints/openai/protocol.py
    vllm/entrypoints/openai/serving_chat.py
    vllm/entrypoints/openai/serving_completion.py
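With this change, a client no longer has to start the server with --return-tokens-as-token-ids to get token-id-formatted logprobs; the option can be enabled per request. Below is a minimal usage sketch of the new request param (not part of the diff; the base URL and model name are placeholder assumptions), passing the option through the OpenAI client's extra_body the same way the updated test does.

# Sketch only: assumes a vLLM OpenAI-compatible server is already running
# at http://localhost:8000/v1 and serving the placeholder model below.
import asyncio

from openai import AsyncOpenAI

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # placeholder model name


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="Hello, world",
        max_tokens=8,
        temperature=0,
        logprobs=1,
        # New in this commit: per-request opt-in, sent as an extra body field.
        extra_body={"return_tokens_as_token_ids": True},
    )
    # With the option set, each logprob token is reported as "token_id:<id>".
    print(completion.choices[0].logprobs.tokens)


asyncio.run(main())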

tests/entrypoints/openai/test_return_tokens_as_ids.py

Lines changed: 30 additions & 14 deletions
@@ -17,18 +17,28 @@
 
 
 @pytest.fixture(scope="module")
-def server_with_return_tokens_as_token_ids_flag(
-        default_server_args):  # noqa: F811
-    args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
-    with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
-        yield remote_server
+def server_fixture(request, default_server_args):  # noqa: F811
+    use_server_flag = request.param
+    if use_server_flag:
+        args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
+        with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
+            yield (remote_server, True)
+    else:
+        with RemoteOpenAIServer(MODEL_NAME,
+                                default_server_args) as remote_server:
+            yield (remote_server, False)
 
 
 @pytest.mark.asyncio
+@pytest.mark.parametrize("server_fixture", [True, False], indirect=True)
 async def test_completion_return_tokens_as_token_ids_completion(
-        server_with_return_tokens_as_token_ids_flag):
-    async with server_with_return_tokens_as_token_ids_flag.get_async_client(
-    ) as client:
+        server_fixture):
+    server, use_server_flag = server_fixture
+    request_args = {}
+    if not use_server_flag:
+        request_args["return_tokens_as_token_ids"] = True
+
+    async with server.get_async_client() as client:
 
         completion = await client.completions.create(
             model=MODEL_NAME,
@@ -39,7 +49,8 @@ async def test_completion_return_tokens_as_token_ids_completion(
             echo=True,
             temperature=0,
             max_tokens=10,
-            logprobs=1)
+            logprobs=1,
+            extra_body=request_args)
 
         text = completion.choices[0].text
         token_strs = completion.choices[0].logprobs.tokens
@@ -60,10 +71,14 @@ async def test_completion_return_tokens_as_token_ids_completion(
 
 
 @pytest.mark.asyncio
-async def test_chat_return_tokens_as_token_ids_completion(
-        server_with_return_tokens_as_token_ids_flag):
-    async with server_with_return_tokens_as_token_ids_flag.get_async_client(
-    ) as client:
+@pytest.mark.parametrize("server_fixture", [True, False], indirect=True)
+async def test_chat_return_tokens_as_token_ids_completion(server_fixture):
+    server, use_server_flag = server_fixture
+    request_args = {}
+    if not use_server_flag:
+        request_args["return_tokens_as_token_ids"] = True
+
+    async with server.get_async_client() as client:
         response = await client.chat.completions.create(
             model=MODEL_NAME,
             # Include Unicode characters to test for dividing a single
@@ -78,7 +93,8 @@ async def test_chat_return_tokens_as_token_ids_completion(
             }],
             temperature=0,
             max_tokens=8,
-            logprobs=True)
+            logprobs=True,
+            extra_body=request_args)
 
         text = response.choices[0].message.content
         tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
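
The rewritten test relies on pytest's indirect parametrization: with indirect=True, each value in the parametrize list is delivered to the fixture as request.param (here selecting whether the server is launched with the flag), and the fixture's yielded value is what the test receives. A minimal standalone illustration of that mechanism (unrelated to vLLM):

import pytest


@pytest.fixture
def server_mode(request):
    # request.param carries the parametrize value because indirect=True below.
    return "server-flag" if request.param else "request-param"


@pytest.mark.parametrize("server_mode", [True, False], indirect=True)
def test_server_mode(server_mode):
    # The test sees the fixture's return value, not the raw boolean.
    assert server_mode in ("server-flag", "request-param")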

vllm/entrypoints/openai/protocol.py

Lines changed: 12 additions & 0 deletions
@@ -369,6 +369,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "arguments. For example: {'qualname': "
             "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
             "{'param': 'value'}}."))
+    return_tokens_as_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified with 'logprobs', tokens are represented "
+            " as strings of the form 'token_id:{token_id}' so that tokens "
+            "that are not JSON-encodable can be identified."))
 
     # doc: end-chat-completion-extra-params
 
@@ -739,6 +745,12 @@ class CompletionRequest(OpenAIBaseModel):
             "arguments. For example: {'qualname': "
             "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
             "{'param': 'value'}}."))
+    return_tokens_as_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified with 'logprobs', tokens are represented "
+            " as strings of the form 'token_id:{token_id}' so that tokens "
+            "that are not JSON-encodable can be identified."))
 
     # doc: end-completion-extra-params
 
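As the field description above states, enabling the option makes logprob tokens come back as strings of the form 'token_id:{token_id}'. A small sketch of how a consumer might map those strings back to integer ids (the helper name is ours, not part of the commit):

def parse_token_ids(tokens: list[str]) -> list[int]:
    """Convert 'token_id:<id>' strings back into integer token ids."""
    ids: list[int] = []
    for tok in tokens:
        if not tok.startswith("token_id:"):
            raise ValueError(f"unexpected token format: {tok!r}")
        ids.append(int(tok.removeprefix("token_id:")))
    return ids


assert parse_token_ids(["token_id:1", "token_id:29871"]) == [1, 29871]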
vllm/entrypoints/openai/serving_chat.py

Lines changed: 13 additions & 8 deletions
@@ -450,6 +450,8 @@ async def chat_completion_stream_generator(
                             top_logprobs=output.logprobs,
                             tokenizer=tokenizer,
                             num_output_top_logprobs=request.top_logprobs,
+                            return_as_token_id=request.
+                            return_tokens_as_token_ids,
                         )
                     else:
                         logprobs = None
@@ -705,6 +707,7 @@ async def chat_completion_full_generator(
                 top_logprobs=out_logprobs,
                 num_output_top_logprobs=request.top_logprobs,
                 tokenizer=tokenizer,
+                return_as_token_id=request.return_tokens_as_token_ids,
             )
         else:
             logprobs = None
@@ -852,13 +855,14 @@ async def chat_completion_full_generator(
 
     def _get_top_logprobs(
            self, logprobs: dict[int, Logprob], top_logprobs: Optional[int],
-           tokenizer: AnyTokenizer) -> list[ChatCompletionLogProb]:
+           tokenizer: AnyTokenizer,
+           should_return_as_token_id: bool) -> list[ChatCompletionLogProb]:
        return [
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
-               return_as_token_id=self.return_tokens_as_token_ids)),
+               return_as_token_id=should_return_as_token_id)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
@@ -872,15 +876,18 @@ def _create_chat_logprobs(
         top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
         tokenizer: AnyTokenizer,
         num_output_top_logprobs: Optional[int] = None,
+        return_as_token_id: Optional[bool] = None,
     ) -> ChatCompletionLogProbs:
         """Create OpenAI-style logprobs."""
         logprobs_content: list[ChatCompletionLogProbsContent] = []
 
+        should_return_as_token_id = return_as_token_id if \
+            return_as_token_id is not None else self.return_tokens_as_token_ids
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
-                if self.return_tokens_as_token_ids:
+                if should_return_as_token_id:
                     token = f"token_id:{token_id}"
 
                 logprobs_content.append(
@@ -898,16 +905,14 @@ def _create_chat_logprobs(
                             step_token,
                             token_id,
                             tokenizer,
-                            self.return_tokens_as_token_ids,
+                            should_return_as_token_id,
                         ),
                         logprob=max(step_token.logprob, -9999.0),
                         bytes=None if step_decoded is None else list(
                             step_decoded.encode("utf-8", errors="replace")),
                         top_logprobs=self._get_top_logprobs(
-                            step_top_logprobs,
-                            num_output_top_logprobs,
-                            tokenizer,
-                        ),
+                            step_top_logprobs, num_output_top_logprobs,
+                            tokenizer, should_return_as_token_id),
                     ))
 
         return ChatCompletionLogProbs(content=logprobs_content)

vllm/entrypoints/openai/serving_completion.py

Lines changed: 9 additions & 3 deletions
@@ -316,6 +316,8 @@ async def completion_stream_generator(
                                 num_output_top_logprobs=request.logprobs,
                                 tokenizer=tokenizer,
                                 initial_text_offset=previous_text_lens[i],
+                                return_as_token_id=request.
+                                return_tokens_as_token_ids,
                             )
                         else:
                             logprobs = None
@@ -436,6 +438,7 @@ def request_output_to_completion_response(
                     top_logprobs=out_logprobs,
                     tokenizer=tokenizer,
                     num_output_top_logprobs=request.logprobs,
+                    return_as_token_id=request.return_tokens_as_token_ids,
                 )
             else:
                 logprobs = None
@@ -477,6 +480,7 @@ def _create_completion_logprobs(
         num_output_top_logprobs: int,
         tokenizer: AnyTokenizer,
         initial_text_offset: int = 0,
+        return_as_token_id: Optional[bool] = None,
     ) -> CompletionLogProbs:
         """Create logprobs for OpenAI Completion API."""
         out_text_offset: list[int] = []
@@ -486,11 +490,13 @@
 
         last_token_len = 0
 
+        should_return_as_token_id = return_as_token_id if \
+            return_as_token_id is not None else self.return_tokens_as_token_ids
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
                 token = tokenizer.decode(token_id)
-                if self.return_tokens_as_token_ids:
+                if should_return_as_token_id:
                     token = f"token_id:{token_id}"
 
                 out_tokens.append(token)
@@ -503,7 +509,7 @@
                     step_token,
                     token_id,
                     tokenizer,
-                    return_as_token_id=self.return_tokens_as_token_ids,
+                    return_as_token_id=should_return_as_token_id,
                 )
                 token_logprob = max(step_token.logprob, -9999.0)
 
@@ -520,7 +526,7 @@
                    self._get_decoded_token(top_lp[1],
                                            top_lp[0],
                                            tokenizer,
-                                           return_as_token_id=self.return_tokens_as_token_ids):
+                                           return_as_token_id=should_return_as_token_id):
                    max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
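
Both _create_chat_logprobs and _create_completion_logprobs now resolve the effective setting the same way: the per-request return_tokens_as_token_ids value, when present, takes precedence, and otherwise the server-wide --return-tokens-as-token-ids default applies. A standalone restatement of that rule (the function name below is ours, for illustration only):

from typing import Optional


def resolve_return_as_token_id(request_value: Optional[bool],
                               server_default: bool) -> bool:
    # Request-level setting wins when explicitly provided (True or False);
    # otherwise fall back to the server-wide flag.
    return request_value if request_value is not None else server_default


assert resolve_return_as_token_id(None, True) is True    # flag-only server
assert resolve_return_as_token_id(True, False) is True   # opt in per request
assert resolve_return_as_token_id(False, True) is False  # opt out per request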
