diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py
index 82a5627aa1d63..7c7232dbccaa7 100644
--- a/tests/entrypoints/openai/test_embedding.py
+++ b/tests/entrypoints/openai/test_embedding.py
@@ -1,3 +1,6 @@
+import base64
+
+import numpy as np
 import openai
 import pytest
 import ray
@@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 17
     assert embeddings.usage.total_tokens == 17
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
+                                       model_name: str):
+    input_texts = [
+        "Hello my name is",
+        "The best thing about vLLM is that it supports many different models"
+    ]
+
+    responses_float = await embedding_client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="float")
+
+    responses_base64 = await embedding_client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="base64")
+
+    decoded_responses_base64_data = []
+    for data in responses_base64.data:
+        decoded_responses_base64_data.append(
+            np.frombuffer(base64.b64decode(data.embedding),
+                          dtype="float").tolist())
+
+    assert responses_float.data[0].embedding == decoded_responses_base64_data[
+        0]
+    assert responses_float.data[1].embedding == decoded_responses_base64_data[
+        1]
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 0ad46cbea2ce6..d1568cb3a773c 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
 class EmbeddingResponseData(BaseModel):
     index: int
     object: str = "embedding"
-    embedding: List[float]
+    embedding: Union[List[float], str]
 
 
 class EmbeddingResponse(BaseModel):
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index cbf09f173fb66..4838cb7d0255a 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -1,6 +1,8 @@
+import base64
 import time
 from typing import AsyncIterator, List, Optional, Tuple
 
+import numpy as np
 from fastapi import Request
 
 from vllm.config import ModelConfig
@@ -20,19 +22,18 @@
 
 
 def request_output_to_embedding_response(
-    final_res_batch: List[EmbeddingRequestOutput],
-    request_id: str,
-    created_time: int,
-    model_name: str,
-) -> EmbeddingResponse:
+        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+        created_time: int, model_name: str,
+        encoding_format: str) -> EmbeddingResponse:
     data: List[EmbeddingResponseData] = []
     num_prompt_tokens = 0
     for idx, final_res in enumerate(final_res_batch):
         assert final_res is not None
         prompt_token_ids = final_res.prompt_token_ids
-
-        embedding_data = EmbeddingResponseData(
-            index=idx, embedding=final_res.outputs.embedding)
+        embedding = final_res.outputs.embedding
+        if encoding_format == "base64":
+            embedding = base64.b64encode(np.array(embedding))
+        embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
         data.append(embedding_data)
 
         num_prompt_tokens += len(prompt_token_ids)
@@ -72,10 +73,8 @@ async def create_embedding(self, request: EmbeddingRequest,
         if error_check_ret is not None:
             return error_check_ret
 
-        # Return error for unsupported features.
-        if request.encoding_format == "base64":
-            return self.create_error_response(
-                "base64 encoding is not currently supported")
+        encoding_format = (request.encoding_format
+                           if request.encoding_format else "float")
         if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")
@@ -129,7 +128,8 @@ async def create_embedding(self, request: EmbeddingRequest,
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
             response = request_output_to_embedding_response(
-                final_res_batch, request_id, created_time, model_name)
+                final_res_batch, request_id, created_time, model_name,
+                encoding_format)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
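
Not part of the patch above: a minimal client-side sketch of how the new encoding_format="base64" option could be consumed once this change lands. The base URL, API key, and model name are placeholders for illustration, and the float64 dtype is an assumption taken from the server-side np.array(embedding) call in serving_embedding.py (a list of Python floats defaults to float64).

import base64

import numpy as np
import openai

# Placeholder endpoint and model name; point these at your own vLLM deployment.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.embeddings.create(
    input=["Hello my name is"],
    model="intfloat/e5-mistral-7b-instruct",
    encoding_format="base64",
)

# Passing encoding_format="base64" explicitly makes the OpenAI client return the
# raw base64 string, so decode it the same way the test above does.
embedding = np.frombuffer(base64.b64decode(response.data[0].embedding),
                          dtype="float").tolist()
print(len(embedding))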