
Commit 199fa94

zifeitong authored and Alvant committed
[Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (vllm-project#6954)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 783402f commit 199fa94
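
The gist of the change, as a minimal sketch rather than the actual vLLM code: when an OpenAI-compatible request arrives without max_tokens, the serving layer now fills SamplingParams.max_tokens with the remaining context window (model length minus prompt length) instead of mutating the request during validation. The helper name and parameters below are illustrative, not part of the diff:

from typing import List, Optional

def resolve_max_tokens(requested: Optional[int], max_model_len: int,
                       prompt_token_ids: List[int]) -> int:
    # Mirrors what to_sampling_params() now does with its new
    # default_max_tokens argument: keep an explicit value, otherwise
    # fall back to the unused part of the context window.
    if requested is not None:
        return requested
    return max_model_len - len(prompt_token_ids)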

File tree: 5 files changed, +92 -44 lines

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 39 additions & 0 deletions
@@ -1,7 +1,12 @@
 import asyncio
+from contextlib import suppress
 from dataclasses import dataclass
+from unittest.mock import MagicMock
 
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -42,3 +47,37 @@ async def _async_serving_chat_init():
 def test_async_serving_chat_init():
     serving_completion = asyncio.run(_async_serving_chat_init())
     assert serving_completion.chat_template == CHAT_TEMPLATE
+
+
+def test_serving_chat_should_set_correct_max_tokens():
+    mock_engine = MagicMock(spec=AsyncLLMEngine)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     served_model_names=[MODEL_NAME],
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     lora_modules=None,
+                                     prompt_adapters=None,
+                                     request_logger=None)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    # AsyncLLMEngine.generate(inputs, sampling_params, ...)
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    req.max_tokens = 10
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
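
The hard-coded 93 above presumably reflects the same arithmetic the server now performs: the MockModelConfig's max_model_len minus the token count of the rendered chat prompt. A hedged restatement of that relationship (the function and its arguments are illustrative, not part of the test):

def expected_default_max_tokens(max_model_len: int,
                                prompt_token_count: int) -> int:
    # For this file's MockModelConfig and the gpt2-tokenized prompt,
    # this difference is what the assertion pins as 93.
    return max_model_len - prompt_token_count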

vllm/entrypoints/openai/protocol.py

Lines changed: 23 additions & 7 deletions
@@ -11,7 +11,7 @@
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.utils import random_uuid
 
 
@@ -215,15 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_sampling_params(self,
-                           tokenizer: PreTrainedTokenizer) -> SamplingParams:
-        # We now allow logprobs being true without top_logrobs.
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
 
+        # We now allow logprobs being true without top_logrobs.
         logits_processors = get_logits_processors(
             logit_bias=self.logit_bias,
             allowed_token_ids=None,
             tokenizer=tokenizer,
         )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)
 
         return SamplingParams(
             n=self.n,
@@ -241,7 +248,7 @@ def to_sampling_params(self,
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
@@ -395,14 +402,23 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
-    def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
         echo_without_generation = self.echo and self.max_tokens == 0
 
         logits_processors = get_logits_processors(
             logit_bias=self.logit_bias,
             allowed_token_ids=self.allowed_token_ids,
             tokenizer=tokenizer,
         )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)
 
         return SamplingParams(
             n=self.n,
@@ -419,7 +435,7 @@ def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
             stop_token_ids=self.stop_token_ids,
             logprobs=self.logprobs,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            max_tokens=max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,

vllm/entrypoints/openai/serving_chat.py

Lines changed: 8 additions & 15 deletions
@@ -25,8 +25,6 @@
                                                     PromptAdapterPath)
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.multimodal import MultiModalDataDict
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
@@ -134,28 +132,23 @@ async def create_chat_completion(
 
         request_id = f"chat-{random_uuid()}"
         try:
-            sampling_params = request.to_sampling_params(tokenizer)
-            decoding_config = await self.engine.get_decoding_config()
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
-                await
-                get_guided_decoding_logits_processor(guided_decoding_backend,
-                                                     request, tokenizer))
-            if guided_decode_logits_processor:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
-                sampling_params.logits_processors.append(
-                    guided_decode_logits_processor)
+                await self._guided_decode_logits_processor(request, tokenizer))
 
             prompt_inputs = self._tokenize_prompt_input(
                 request,
                 tokenizer,
                 prompt,
-                truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
 
+            sampling_params = request.to_sampling_params(
+                tokenizer,
+                guided_decode_logits_processor,
+                default_max_tokens=self.max_model_len -
+                len(prompt_inputs["prompt_token_ids"]))
+
             self._log_inputs(request_id,
                              prompt_inputs,
                              params=sampling_params,
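
Note the reordering above: the prompt is now tokenized before SamplingParams is built, because default_max_tokens depends on len(prompt_inputs["prompt_token_ids"]). As a side effect, truncate_prompt_tokens is read directly from the request rather than off a SamplingParams object that no longer exists at that point.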

vllm/entrypoints/openai/serving_completion.py

Lines changed: 9 additions & 18 deletions
@@ -24,8 +24,6 @@
                                                     OpenAIServing,
                                                     PromptAdapterPath)
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -95,31 +93,24 @@ async def create_completion(self, request: CompletionRequest,
 
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
-            sampling_params = request.to_sampling_params(tokenizer)
-            decoding_config = await self.engine.get_decoding_config()
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
-            guided_decode_logit_processor = (
-                await
-                get_guided_decoding_logits_processor(guided_decoding_backend,
-                                                     request, tokenizer))
-            if guided_decode_logit_processor is not None:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
-                sampling_params.logits_processors.append(
-                    guided_decode_logit_processor)
-
+            guided_decode_logits_processor = (
+                await self._guided_decode_logits_processor(request, tokenizer))
             prompts = list(
                 self._tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
                     request.prompt,
-                    truncate_prompt_tokens=sampling_params.
-                    truncate_prompt_tokens,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 ))
 
             for i, prompt_inputs in enumerate(prompts):
+                sampling_params = request.to_sampling_params(
+                    tokenizer,
+                    guided_decode_logits_processor,
+                    default_max_tokens=self.max_model_len -
+                    len(prompt_inputs["prompt_token_ids"]))
+
                 request_id_item = f"{request_id}-{i}"
 
                 self._log_inputs(request_id_item,
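
In the completion path the default is computed inside the per-prompt loop, so each prompt of a batched request gets its own budget. A hedged sketch of that per-prompt arithmetic (the function and its arguments are illustrative, not part of the diff):

def per_prompt_default_max_tokens(max_model_len, prompts):
    # One entry per batched prompt, matching the loop above: the budget is
    # whatever the prompt leaves of the context window.
    return [
        max_model_len - len(p["prompt_token_ids"]) for p in prompts
    ]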

vllm/entrypoints/openai/serving_engine.py

Lines changed: 13 additions & 4 deletions
@@ -25,9 +25,11 @@
 from vllm.inputs import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer_group import AnyTokenizer
 
@@ -150,6 +152,15 @@ def create_streaming_error_response(
         })
         return json_str
 
+    async def _guided_decode_logits_processor(
+            self, request: Union[ChatCompletionRequest, CompletionRequest],
+            tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
+        decoding_config = await self.engine.get_decoding_config()
+        guided_decoding_backend = request.guided_decoding_backend \
+            or decoding_config.guided_decoding_backend
+        return await get_guided_decoding_logits_processor(
+            guided_decoding_backend, request, tokenizer)
+
     async def _check_model(
         self,
         request: AnyRequest,
@@ -254,9 +265,7 @@ def _validate_input(
                     f"{self.max_model_len} tokens. However, you requested "
                     f"{token_num} tokens in the messages, "
                     f"Please reduce the length of the messages.")
-            request.max_tokens = self.max_model_len - token_num
-
-        if token_num + request.max_tokens > self.max_model_len:
+        elif token_num + request.max_tokens > self.max_model_len:
             raise ValueError(
                 f"This model's maximum context length is "
                 f"{self.max_model_len} tokens. However, you requested "
