
Commit 7eaf477

Adds truncate_prompt_tokens param for embeddings creation
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
1 parent 83caf35 commit 7eaf477
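
The tests below exercise the new parameter end to end. For context, a minimal client-side call might look like the following sketch; the base URL, API key, and model name are placeholders for whatever vLLM deployment is being queried, and `truncate_prompt_tokens` rides in `extra_body` because it is a vLLM extension rather than part of the upstream OpenAI API.

    import openai

    # Placeholder connection details; point these at a running vLLM server.
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # truncate_prompt_tokens is vLLM-specific, so it is sent via extra_body.
    # The server truncates the tokenized prompt to at most 10 tokens.
    response = client.embeddings.create(
        model="intfloat/e5-mistral-7b-instruct",  # example embedding model
        input=["Como o Brasil pode fomentar o desenvolvimento de modelos de IA?"],
        extra_body={"truncate_prompt_tokens": 10},
    )
    print(len(response.data[0].embedding))  # embedding width, e.g. 4096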

File tree

3 files changed: +76 -5 lines changed


tests/entrypoints/openai/test_embedding.py

+61
@@ -144,3 +144,64 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
             0].embedding
     assert responses_float.data[1].embedding == responses_default.data[
         1].embedding
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_single_embedding_truncation(
+        embedding_client: openai.AsyncOpenAI, model_name: str):
+    input_texts = [
+        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
+    ]
+
+    # test single embedding
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_texts,
+        extra_body={"truncate_prompt_tokens": 10})
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 10
+    assert embeddings.usage.total_tokens == 10
+
+    input_tokens = [
+        1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
+        9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
+    ]
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_tokens,
+        extra_body={"truncate_prompt_tokens": 10})
+
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 10
+    assert embeddings.usage.total_tokens == 10
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_single_embedding_truncation_invalid(
+        embedding_client: openai.AsyncOpenAI, model_name: str):
+    input_texts = [
+        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
+    ]
+
+    with pytest.raises(openai.BadRequestError):
+        embeddings = await embedding_client.embeddings.create(
+            model=model_name,
+            input=input_texts,
+            extra_body={"truncate_prompt_tokens": 8193})
+        assert "error" in embeddings.object
+        assert "truncate_prompt_tokens value is greater than max_model_len. "\
+            "Please, select a smaller truncation size." in embeddings.message

vllm/entrypoints/openai/protocol.py

+1
@@ -671,6 +671,7 @@ class EmbeddingRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
 
     # doc: begin-embedding-pooling-params
     additional_data: Optional[Any] = None
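
For reference, the `Annotated[int, Field(ge=1)]` annotation means pydantic rejects zero or negative values before the request ever reaches the handler; only the upper bound has to be checked against `max_model_len` at serving time. A minimal standalone sketch of the same constraint, using a stand-in model rather than the real `EmbeddingRequest`:

    from typing import Annotated, Optional

    from pydantic import BaseModel, Field, ValidationError

    # Stand-in model carrying only the new field, for illustration.
    class ExampleRequest(BaseModel):
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    ExampleRequest(truncate_prompt_tokens=10)  # accepted
    ExampleRequest()                           # accepted: defaults to None
    try:
        ExampleRequest(truncate_prompt_tokens=0)  # violates ge=1
    except ValidationError as err:
        print(err)  # reports "greater than or equal to 1"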

vllm/entrypoints/openai/serving_embedding.py

+14 -5
@@ -110,6 +110,17 @@ async def create_embedding(
         request_id = f"embd-{random_uuid()}"
         created_time = int(time.monotonic())
 
+        truncate_prompt_tokens = None
+
+        if request.truncate_prompt_tokens is not None:
+            if request.truncate_prompt_tokens <= self.max_model_len:
+                truncate_prompt_tokens = request.truncate_prompt_tokens
+            else:
+                return self.create_error_response(
+                    "truncate_prompt_tokens value is "
+                    "greater than max_model_len."
+                    " Please, select a smaller truncation size.")
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
         try:
@@ -123,11 +134,9 @@
             pooling_params = request.to_pooling_params()
 
             prompts = list(
-                self._tokenize_prompt_input_or_inputs(
-                    request,
-                    tokenizer,
-                    request.input,
-                ))
+                self._tokenize_prompt_input_or_inputs(request, tokenizer,
+                                                      request.input,
+                                                      truncate_prompt_tokens))
 
             for i, prompt_inputs in enumerate(prompts):
                 request_id_item = f"{request_id}-{i}"
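
The diff routes the validated value into `_tokenize_prompt_input_or_inputs`, whose internals are not shown here. Assuming that helper ultimately calls a Hugging Face tokenizer, the truncation step plausibly reduces to something like the sketch below; the function name and model are illustrative assumptions, not the actual implementation.

    from typing import List, Optional

    from transformers import AutoTokenizer

    # Illustrative stand-in for the tokenization helper; model is arbitrary.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def tokenize_prompt(prompt: str,
                        truncate_prompt_tokens: Optional[int] = None
                        ) -> List[int]:
        if truncate_prompt_tokens is None:
            return tokenizer(prompt).input_ids
        # With truncation=True, the encoding is capped at max_length tokens.
        return tokenizer(prompt,
                         truncation=True,
                         max_length=truncate_prompt_tokens).input_ids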
