6 changes: 6 additions & 0 deletions examples/online_serving/pooling/README.md
@@ -6,6 +6,12 @@
python examples/online_serving/pooling/cohere_rerank_client.py
```

## Embedding embed_dtype usage

```bash
python examples/online_serving/pooling/embedding_embed_dtype_client.py
```

## Jinaai rerank usage

```bash
59 changes: 59 additions & 0 deletions examples/online_serving/pooling/embedding_embed_dtype_client.py
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for embedding API using vLLM API server
NOTE:
start a supported embeddings model server with `vllm serve`, e.g.
vllm serve intfloat/e5-small
"""

import argparse
import base64

import requests
import torch

from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
headers = {"User-Agent": "Test Client"}
response = requests.post(api_url, headers=headers, json=prompt)
return response


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="intfloat/e5-small")

return parser.parse_args()


def main(args):
api_url = f"http://{args.host}:{args.port}/v1/embeddings"
model_name = args.model

for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
prompt = {
"model": model_name,
"input": "vLLM is great!",
"encoding_format": "base64",
"embed_dtype": embed_dtype,
}
response = post_http_request(prompt=prompt, api_url=api_url)

embedding = []
for data in response.json()["data"]:
embedding.append(
torch.frombuffer(
base64.b64decode(data["embedding"]), dtype=torch_dtype
).to(torch.float32)
)
embedding = torch.cat(embedding)
print(embed_dtype, embedding.shape)


if __name__ == "__main__":
args = parse_args()
main(args)
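For comparison, the same request can also go through the official OpenAI Python client by passing the vLLM-specific `embed_dtype` field via `extra_body`, since it is not part of the upstream OpenAI embeddings API. This is a minimal sketch, assuming a server started with `vllm serve intfloat/e5-small` on localhost:8000; when `encoding_format="base64"` is requested explicitly, the SDK returns the raw base64 string, so the decode step mirrors the example above.

```python
import base64

import torch
from openai import OpenAI

# Assumes `vllm serve intfloat/e5-small` is listening on localhost:8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="intfloat/e5-small",
    input="vLLM is great!",
    encoding_format="base64",
    # embed_dtype is a vLLM extension, so it has to be sent via extra_body.
    extra_body={"embed_dtype": "float16"},
)

# With an explicit encoding_format="base64", the SDK leaves the payload as a
# base64 string, so decode it with the matching torch dtype.
embedding = torch.frombuffer(
    base64.b64decode(response.data[0].embedding), dtype=torch.float16
).to(torch.float32)
print(embedding.shape)
```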
74 changes: 73 additions & 1 deletion tests/entrypoints/pooling/openai/test_embedding.py
@@ -14,7 +14,10 @@
from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.entrypoints.openai.protocol import (
EMBED_DTYPE_TO_TORCH_DTYPE,
EmbeddingResponse,
)
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "intfloat/multilingual-e5-small"
@@ -244,6 +247,75 @@ async def test_batch_base64_embedding(
run_embedding_correctness_test(hf_model, input_texts, default_data)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype(
hf_model, server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]

responses_float = await client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float"
)
float_data = [d.embedding for d in responses_float.data]

for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
responses_base64 = requests.post(
server.url_for("/v1/embeddings"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": embed_dtype,
},
)

base64_data = []
for data in responses_base64.json()["data"]:
base64_data.append(
torch.frombuffer(base64.b64decode(data["embedding"]), dtype=torch_dtype)
.to(torch.float32)
.tolist()
)

check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=base64_data,
name_0="float_data",
name_1="base64_data",
tol=1e-2,
)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype_not_supported(
hf_model, server: RemoteOpenAIServer, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]

bad_embed_dtype = "bad_embed_dtype"

responses_base64 = requests.post(
server.url_for("/v1/embeddings"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": bad_embed_dtype,
},
)

assert responses_base64.status_code == 400
assert responses_base64.json()["error"]["message"].startswith(
f"embed_dtype={bad_embed_dtype!r} is not supported."
)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str):
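The 400 path exercised by `test_base64_embed_dtype_not_supported` can also be surfaced on the client side. A small sketch in the same raw-`requests` style as these tests (the `embed_base64` helper is hypothetical, not part of vLLM):

```python
import base64

import requests


def embed_base64(api_url: str, model: str, text: str, embed_dtype: str) -> bytes:
    """Request a base64 embedding, raising with the server's message on a 400."""
    response = requests.post(
        api_url,
        json={
            "model": model,
            "input": [text],
            "encoding_format": "base64",
            "embed_dtype": embed_dtype,
        },
    )
    if response.status_code == 400:
        # The server responds with e.g. "embed_dtype='...' is not supported."
        raise ValueError(response.json()["error"]["message"])
    response.raise_for_status()
    return base64.b64decode(response.json()["data"][0]["embedding"])
```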
77 changes: 76 additions & 1 deletion tests/entrypoints/pooling/openai/test_pooling.py
@@ -6,10 +6,11 @@
import numpy as np
import pytest
import requests
import torch

from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import PoolingResponse
from vllm.entrypoints.openai.protocol import EMBED_DTYPE_TO_TORCH_DTYPE, PoolingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "internlm/internlm2-1_8b-reward"
@@ -248,6 +249,80 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str)
)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype(server: RemoteOpenAIServer, model_name: str):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]

url = server.url_for("pooling")
float_response = requests.post(
url,
json={
"model": model_name,
"input": input_texts,
"encoding_format": "float",
},
)
responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]

for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
responses_base64 = requests.post(
url,
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": embed_dtype,
},
)

base64_data = []
for data in responses_base64.json()["data"]:
base64_data.append(
torch.frombuffer(base64.b64decode(data["data"]), dtype=torch_dtype)
.to(torch.float32)
.tolist()
)

check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=base64_data,
name_0="float_data",
name_1="base64_data",
tol=1e-2,
)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype_not_supported(
server: RemoteOpenAIServer, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]

bad_embed_dtype = "bad_embed_dtype"

responses_base64 = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": bad_embed_dtype,
},
)

assert responses_base64.status_code == 400
assert responses_base64.json()["error"]["message"].startswith(
f"embed_dtype={bad_embed_dtype!r} is not supported."
)


@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
input_texts = [
45 changes: 42 additions & 3 deletions vllm/entrypoints/openai/protocol.py
@@ -83,6 +83,18 @@
)
from vllm.utils import random_uuid, resolve_obj_by_qualname

EMBED_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
# NOTE: It is unclear whether CPUs on all platforms support the fp8 formats.
# embed_dtype uses fp8 only as a data representation for serializing
# embeddings; no fp8 computation is performed, and the conversion happens
# on the CPU.
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
}

logger = init_logger(__name__)

_LONG_INFO = torch.iinfo(torch.long)
@@ -1517,8 +1529,17 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
"through out the inference process and return in response."
),
)
normalize: bool | None = None

normalize: bool | None = Field(
default=None,
description="Whether to normalize the embedding outputs. Defaults to True.",
)
embed_dtype: str = Field(
default="float32",
description=(
"Which dtype to use for base64 encoding. Defaults to float32 to "
"match the OpenAI Python client behavior."
),
)
# --8<-- [end:embedding-extra-params]

def to_pooling_params(self):
@@ -1594,7 +1615,17 @@ class EmbeddingChatRequest(OpenAIBaseModel):
"through out the inference process and return in response."
),
)
normalize: bool | None = None
normalize: bool | None = Field(
default=None,
description="Whether to normalize the embedding outputs. Defaults to True.",
)
embed_dtype: str = Field(
default="float32",
description=(
"Which dtype to use for base64 encoding. Defaults to float32 to "
"match the OpenAI Python client behavior."
),
)
# --8<-- [end:chat-embedding-extra-params]

@model_validator(mode="before")
@@ -1639,6 +1670,14 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
"""
softmax: bool = True

embed_dtype: str = Field(
default="float32",
description=(
"Which dtype to use for base64 encoding. Defaults to float32 to "
"match the OpenAI Python client behavior."
),
)

def to_pooling_params(self):
return PoolingParams(task="encode", softmax=self.softmax)

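Taken together, `EMBED_DTYPE_TO_TORCH_DTYPE` and the clients above imply a simple wire format: the embedding is cast to the requested dtype, its raw bytes are base64-encoded, and the receiver reinterprets them with `torch.frombuffer`. The following self-contained sketch (the serialization side is an assumption, not vLLM's actual server code) illustrates the round trip and why fp8 here is purely a data representation: a narrower `embed_dtype` only shrinks the payload and adds quantization error, with no fp8 arithmetic involved.

```python
import base64

import torch

# Local copy of the mapping added to vllm/entrypoints/openai/protocol.py above.
EMBED_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

embedding = torch.randn(384)  # e5-small embeddings are 384-dimensional

for name, dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
    # Assumed server side: cast to the wire dtype and base64-encode the raw bytes.
    raw = embedding.to(dtype).view(torch.uint8).numpy().tobytes()
    payload = base64.b64encode(raw)
    # Client side, as in the examples above: reinterpret and upcast to float32.
    decoded = torch.frombuffer(base64.b64decode(payload), dtype=dtype).to(torch.float32)
    error = (decoded - embedding).abs().max().item()
    print(f"{name:9} payload={len(payload):5d} B  max_abs_err={error:.4f}")
```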