Skip to content

Commit a324c13

Browse files
committed
[Bugfix] Fix score api for missing max_model_len validation
Signed-off-by: Wallas Santos <wallashss@ibm.com>
1 parent 2e0e017 commit a324c13

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

tests/entrypoints/openai/test_score.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ def server():
2020

2121
@pytest.mark.asyncio
2222
@pytest.mark.parametrize("model_name", [MODEL_NAME])
23-
async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
24-
model_name: str):
23+
def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
2524
text_1 = "What is the capital of France?"
2625
text_2 = [
2726
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
@@ -45,8 +44,7 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
4544

4645
@pytest.mark.asyncio
4746
@pytest.mark.parametrize("model_name", [MODEL_NAME])
48-
async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
49-
model_name: str):
47+
def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
5048
text_1 = [
5149
"What is the capital of the United States?",
5250
"What is the capital of France?"
@@ -73,8 +71,7 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
7371

7472
@pytest.mark.asyncio
7573
@pytest.mark.parametrize("model_name", [MODEL_NAME])
76-
async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
77-
model_name: str):
74+
def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
7875
text_1 = "What is the capital of France?"
7976
text_2 = "The capital of France is Paris."
8077

@@ -91,3 +88,41 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
9188
assert score.data is not None
9289
assert len(score.data) == 1
9390
assert score.data[0].score >= 0.9
91+
92+
93+
@pytest.mark.asyncio
94+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
95+
def test_score_max_model_len(model_name: str):
96+
97+
args = ["--enforce-eager", "--max-model-len", "5"]
98+
99+
with RemoteOpenAIServer(model_name, args) as remote_server:
100+
101+
text_1 = "What is the capital of France?"
102+
text_2 = [
103+
"The capital of Brazil is Brasilia.",
104+
"The capital of France is Paris."
105+
]
106+
107+
score_response = requests.post(remote_server.url_for("score"),
108+
json={
109+
"model": model_name,
110+
"text_1": text_1,
111+
"text_2": text_2,
112+
})
113+
assert score_response.status_code == 400
114+
# Assert just a small fragment of the response
115+
assert "Please reduce the length of the input." in \
116+
score_response.text
117+
118+
# Test truncation
119+
score_response = requests.post(remote_server.url_for("score"),
120+
json={
121+
"model": model_name,
122+
"text_1": text_1,
123+
"text_2": text_2,
124+
"truncate_prompt_tokens": 10
125+
})
126+
assert score_response.status_code == 400
127+
assert "Please, select a smaller truncation size." in \
128+
score_response.text

vllm/entrypoints/openai/serving_score.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ async def create_score(
101101
if not self.model_config.is_cross_encoder:
102102
raise ValueError("Model is not cross encoder.")
103103

104+
if truncate_prompt_tokens is not None and \
105+
truncate_prompt_tokens > self.max_model_len:
106+
raise ValueError(
107+
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
108+
f"is greater than max_model_len ({self.max_model_len})."
109+
f" Please, select a smaller truncation size.")
110+
104111
except ValueError as e:
105112
logger.exception("Error in preprocessing prompt inputs")
106113
return self.create_error_response(str(e))
@@ -123,8 +130,19 @@ async def create_score(
123130
prompt_inputs = await tokenize_async(text=q,
124131
text_pair=t,
125132
**tokenization_kwargs)
133+
134+
input_ids = prompt_inputs["input_ids"]
135+
token_num = len(input_ids)
136+
if len(input_ids) > self.max_model_len:
137+
err_msg = (
138+
f"This model's maximum context length is "
139+
f"{self.max_model_len} tokens. However, you requested "
140+
f"{token_num} tokens in the input for score. "
141+
f"Please reduce the length of the input.")
142+
logger.error(err_msg)
143+
return self.create_error_response(err_msg)
126144
engine_prompt = TokensPrompt(
127-
prompt_token_ids=prompt_inputs["input_ids"],
145+
prompt_token_ids=input_ids,
128146
token_type_ids=prompt_inputs.get("token_type_ids"))
129147

130148
request_prompts.append(request_prompt)

0 commit comments

Comments
 (0)