Commit d2a7938

[Frontend][1/N] Improve all pooling tasks | Support FP16 Embedding Base64 (still uses fp32 by default) (#26414)

Authored by noooop, maxdebayser, and DarkLight1337

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Maximilien de Bayser <maxdebayser@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 89342ce commit d2a7938

File tree

8 files changed: +312 -30 lines
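
The feature at a glance: when a request sets `encoding_format: "base64"`, the new `embed_dtype` parameter selects the dtype the server serializes embeddings into before base64-encoding, with `float32` remaining the default. Below is a minimal round-trip sketch, condensed from the example client added in this commit; it assumes a server started locally with `vllm serve intfloat/e5-small` on the default port:

```python
import base64

import requests
import torch

# Ask the server to serialize the embedding as fp16 before base64-encoding it.
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/e5-small",
        "input": "vLLM is great!",
        "encoding_format": "base64",
        "embed_dtype": "float16",  # new parameter; "float32" stays the default
    },
)

# Client-side decode: base64 -> raw bytes -> fp16 tensor -> upcast to fp32.
buf = base64.b64decode(resp.json()["data"][0]["embedding"])
embedding = torch.frombuffer(buf, dtype=torch.float16).to(torch.float32)
print(embedding.shape)
```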

examples/online_serving/pooling/README.md

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,12 @@
 python examples/online_serving/pooling/cohere_rerank_client.py
 ```

+## Embedding embed_dtype usage
+
+```bash
+python examples/online_serving/pooling/embedding_embed_dtype_client.py
+```
+
 ## Jinaai rerank usage

 ```bash
examples/online_serving/pooling/embedding_embed_dtype_client.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Example Python client for embedding API using vLLM API server
+NOTE:
+    start a supported embeddings model server with `vllm serve`, e.g.
+    vllm serve intfloat/e5-small
+"""
+
+import argparse
+import base64
+
+import requests
+import torch
+
+from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE
+
+
+def post_http_request(prompt: dict, api_url: str) -> requests.Response:
+    headers = {"User-Agent": "Test Client"}
+    response = requests.post(api_url, headers=headers, json=prompt)
+    return response
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--model", type=str, default="intfloat/e5-small")
+
+    return parser.parse_args()
+
+
+def main(args):
+    api_url = f"http://{args.host}:{args.port}/v1/embeddings"
+    model_name = args.model
+
+    for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
+        prompt = {
+            "model": model_name,
+            "input": "vLLM is great!",
+            "encoding_format": "base64",
+            "embed_dtype": embed_dtype,
+        }
+        response = post_http_request(prompt=prompt, api_url=api_url)
+
+        embedding = []
+        for data in response.json()["data"]:
+            embedding.append(
+                torch.frombuffer(
+                    base64.b64decode(data["embedding"]), dtype=torch_dtype
+                ).to(torch.float32)
+            )
+        embedding = torch.cat(embedding)
+        print(embed_dtype, embedding.shape)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
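
Run against a live server, the client above loops over every entry in EMBED_DTYPE_TO_TORCH_DTYPE and prints each dtype name alongside the shape of the decoded, fp32-upcast embedding, so it doubles as a quick smoke test for all supported encodings.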

tests/entrypoints/pooling/openai/test_embedding.py

Lines changed: 73 additions & 1 deletion
@@ -14,7 +14,10 @@
 from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
 from tests.models.utils import check_embeddings_close
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.entrypoints.openai.protocol import (
+    EMBED_DTYPE_TO_TORCH_DTYPE,
+    EmbeddingResponse,
+)
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MODEL_NAME = "intfloat/multilingual-e5-small"

@@ -244,6 +247,75 @@ async def test_batch_base64_embedding(
     run_embedding_correctness_test(hf_model, input_texts, default_data)


+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_base64_embed_dtype(
+    hf_model, server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
+):
+    input_texts = [
+        "The best thing about vLLM is that it supports many different models",
+    ]
+
+    responses_float = await client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="float"
+    )
+    float_data = [d.embedding for d in responses_float.data]
+
+    for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
+        responses_base64 = requests.post(
+            server.url_for("/v1/embeddings"),
+            json={
+                "model": model_name,
+                "input": input_texts,
+                "encoding_format": "base64",
+                "embed_dtype": embed_dtype,
+            },
+        )
+
+        base64_data = []
+        for data in responses_base64.json()["data"]:
+            base64_data.append(
+                torch.frombuffer(base64.b64decode(data["embedding"]), dtype=torch_dtype)
+                .to(torch.float32)
+                .tolist()
+            )
+
+        check_embeddings_close(
+            embeddings_0_lst=float_data,
+            embeddings_1_lst=base64_data,
+            name_0="float_data",
+            name_1="base64_data",
+            tol=1e-2,
+        )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_base64_embed_dtype_not_supported(
+    hf_model, server: RemoteOpenAIServer, model_name: str
+):
+    input_texts = [
+        "The best thing about vLLM is that it supports many different models",
+    ]
+
+    bad_embed_dtype = "bad_embed_dtype"
+
+    responses_base64 = requests.post(
+        server.url_for("/v1/embeddings"),
+        json={
+            "model": model_name,
+            "input": input_texts,
+            "encoding_format": "base64",
+            "embed_dtype": bad_embed_dtype,
+        },
+    )
+
+    assert responses_base64.status_code == 400
+    assert responses_base64.json()["error"]["message"].startswith(
+        f"embed_dtype={bad_embed_dtype!r} is not supported."
+    )
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str):

tests/entrypoints/pooling/openai/test_pooling.py

Lines changed: 76 additions & 1 deletion
@@ -6,10 +6,11 @@
 import numpy as np
 import pytest
 import requests
+import torch

 from tests.models.utils import check_embeddings_close
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import PoolingResponse
+from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE, PoolingResponse
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MODEL_NAME = "internlm/internlm2-1_8b-reward"

@@ -248,6 +249,80 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str)
     )


+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_base64_embed_dtype(server: RemoteOpenAIServer, model_name: str):
+    input_texts = [
+        "The best thing about vLLM is that it supports many different models",
+    ]
+
+    url = server.url_for("pooling")
+    float_response = requests.post(
+        url,
+        json={
+            "model": model_name,
+            "input": input_texts,
+            "encoding_format": "float",
+        },
+    )
+    responses_float = PoolingResponse.model_validate(float_response.json())
+    float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
+
+    for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
+        responses_base64 = requests.post(
+            url,
+            json={
+                "model": model_name,
+                "input": input_texts,
+                "encoding_format": "base64",
+                "embed_dtype": embed_dtype,
+            },
+        )
+
+        base64_data = []
+        for data in responses_base64.json()["data"]:
+            base64_data.append(
+                torch.frombuffer(base64.b64decode(data["data"]), dtype=torch_dtype)
+                .to(torch.float32)
+                .tolist()
+            )
+
+        check_embeddings_close(
+            embeddings_0_lst=float_data,
+            embeddings_1_lst=base64_data,
+            name_0="float_data",
+            name_1="base64_data",
+            tol=1e-2,
+        )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_base64_embed_dtype_not_supported(
+    server: RemoteOpenAIServer, model_name: str
+):
+    input_texts = [
+        "The best thing about vLLM is that it supports many different models",
+    ]
+
+    bad_embed_dtype = "bad_embed_dtype"
+
+    responses_base64 = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_texts,
+            "encoding_format": "base64",
+            "embed_dtype": bad_embed_dtype,
+        },
+    )
+
+    assert responses_base64.status_code == 400
+    assert responses_base64.json()["error"]["message"].startswith(
+        f"embed_dtype={bad_embed_dtype!r} is not supported."
+    )
+
+
 @pytest.mark.asyncio
 async def test_invocations(server: RemoteOpenAIServer):
     input_texts = [

vllm/entrypoints/openai/protocol.py

Lines changed: 42 additions & 3 deletions
@@ -83,6 +83,18 @@
 )
 from vllm.utils import random_uuid, resolve_obj_by_qualname

+EMBED_DTYPE_TO_TORCH_DTYPE = {
+    "float32": torch.float32,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    # The fp8 entries use only the fp8 data representation, not fp8
+    # computation, and the conversion happens on the CPU. Whether other
+    # platforms' CPUs support the fp8 data format is uncertain, so this
+    # may break on some of them.
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+}
+
 logger = init_logger(__name__)

 _LONG_INFO = torch.iinfo(torch.long)

@@ -1517,8 +1529,17 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
             "through out the inference process and return in response."
         ),
     )
-    normalize: bool | None = None
-
+    normalize: bool | None = Field(
+        default=None,
+        description="Whether to normalize the embeddings outputs. Default is True.",
+    )
+    embed_dtype: str = Field(
+        default="float32",
+        description=(
+            "What dtype to use for base64 encoding. Default to using "
+            "float32 for base64 encoding to match the OpenAI python client behavior."
+        ),
+    )
     # --8<-- [end:embedding-extra-params]

     def to_pooling_params(self):

@@ -1594,7 +1615,17 @@ class EmbeddingChatRequest(OpenAIBaseModel):
             "through out the inference process and return in response."
         ),
     )
-    normalize: bool | None = None
+    normalize: bool | None = Field(
+        default=None,
+        description="Whether to normalize the embeddings outputs. Default is True.",
+    )
+    embed_dtype: str = Field(
+        default="float32",
+        description=(
+            "Which dtype to use for base64 encoding. Defaults to float32 "
+            "to match OpenAI API."
+        ),
+    )
     # --8<-- [end:chat-embedding-extra-params]

     @model_validator(mode="before")

@@ -1639,6 +1670,14 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     """
     softmax: bool = True

+    embed_dtype: str = Field(
+        default="float32",
+        description=(
+            "What dtype to use for base64 encoding. Default to using "
+            "float32 for base64 encoding to match the OpenAI python client behavior."
+        ),
+    )
+
     def to_pooling_params(self):
         return PoolingParams(task="encode", softmax=self.softmax)
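
What `embed_dtype` buys is payload size: the raw buffer that gets base64-encoded scales with the element size of the chosen dtype, so fp16/bf16 halve the embedding payload relative to fp32 and the fp8 formats quarter it. A standalone sketch (not part of this commit, and assuming a PyTorch build that exposes the float8 dtypes, 2.1+) that checks the element sizes behind the new mapping:

```python
import torch

# Mirrors EMBED_DTYPE_TO_TORCH_DTYPE from vllm/entrypoints/openai/protocol.py.
embed_dtypes = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

for name, dtype in embed_dtypes.items():
    # element_size() is bytes per value, i.e. payload bytes per embedding
    # dimension before the uniform 4/3 base64 expansion.
    print(f"{name}: {torch.tensor([], dtype=dtype).element_size()} byte(s)")
# Prints: float32 -> 4, float16 -> 2, bfloat16 -> 2, fp8_e4m3 -> 1, fp8_e5m2 -> 1
```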
