diff --git a/langport/model/executor/embedding/huggingface.py b/langport/model/executor/embedding/huggingface.py
index d8cfcbb..a99cd3e 100644
--- a/langport/model/executor/embedding/huggingface.py
+++ b/langport/model/executor/embedding/huggingface.py
@@ -4,7 +4,7 @@
 import torch
 
 from langport.model.executor.huggingface import HuggingfaceExecutor
-from langport.protocol.worker_protocol import BaseWorkerResult, EmbeddingWorkerResult, UsageInfo
+from langport.protocol.worker_protocol import BaseWorkerResult, EmbeddingWorkerResult, EmbeddingsObject, UsageInfo
 from langport.workers.embedding_worker import EmbeddingModelWorker
 
 
@@ -118,7 +118,19 @@ def inference(self, worker: "EmbeddingModelWorker"):
 
         self.wakeup()
         # print(batch_size)
-        prompts = [task.input for task in tasks]
+        prompts = []
+        prompts_index = []
+        for task_i, task in enumerate(tasks):
+            task_input = task.input
+            if isinstance(task_input, str):
+                prompts.append(task_input)
+                prompts_index.append(task_i)
+            elif isinstance(task_input, list):
+                prompts.extend(task_input)
+                prompts_index.extend([task_i] * len(task_input))
+            else:
+                raise Exception("Invalid prompt type...")
+
         try:
             tokenizer = self.tokenizer
             model = self.model
@@ -133,7 +145,7 @@
             input_ids = encoded_prompts.input_ids
             if model.config.is_encoder_decoder:
                 decoder_input_ids = torch.full(
-                    (batch_size, 1),
+                    (len(prompts), 1),
                     model.generation_config.decoder_start_token_id,
                     dtype=torch.long,
                     device=self.device,
@@ -151,14 +163,19 @@
             data = model(**encoded_prompts)
             # embeddings = torch.mean(data, dim=1)
             embeddings = self._mean_pooling(data, encoded_prompts['attention_mask'])
-            for i in range(batch_size):
-                token_num = len(tokenizer(prompts[i]).input_ids)
+            for task_i in range(len(tasks)):
+                token_num = 0
+                embedding_list = []
+                for prompt_i in range(len(prompts)):
+                    if prompts_index[prompt_i] == task_i:
+                        token_num += len(tokenizer(prompts[prompt_i]).input_ids)
+                        embedding_list.append(EmbeddingsObject(index=len(embedding_list), embedding=embeddings[prompt_i].tolist()))
                 worker.push_task_result(
-                    tasks[i].task_id,
+                    tasks[task_i].task_id,
                     EmbeddingWorkerResult(
-                        task_id=tasks[i].task_id,
+                        task_id=tasks[task_i].task_id,
                         type="data",
-                        embedding=embeddings[i].tolist(),
+                        embeddings=embedding_list,
                         usage=UsageInfo(prompt_tokens=token_num, total_tokens=token_num),
                     )
                 )
diff --git a/langport/protocol/openai_api_protocol.py b/langport/protocol/openai_api_protocol.py
index 895a0fb..8dc83d4 100644
--- a/langport/protocol/openai_api_protocol.py
+++ b/langport/protocol/openai_api_protocol.py
@@ -115,10 +115,10 @@ class ChatCompletionStreamResponse(BaseModel):
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
 
-
+# TODO: Support List[int] and List[List[int]]
 class EmbeddingsRequest(BaseModel):
     model: str
-    input: str
+    input: Union[str, List[str]]
     user: Optional[str] = None
 
 class EmbeddingsData(BaseModel):
diff --git a/langport/protocol/worker_protocol.py b/langport/protocol/worker_protocol.py
index 241a631..4464c48 100644
--- a/langport/protocol/worker_protocol.py
+++ b/langport/protocol/worker_protocol.py
@@ -76,7 +76,7 @@ class BaseWorkerTask(BaseModel):
 
 class EmbeddingsTask(BaseWorkerTask):
     model: str
-    input: str
+    input: Union[str, List[str]]
     user: Optional[str] = None
 
 class GenerationTask(BaseWorkerTask):
@@ -103,8 +103,12 @@ class BaseWorkerResult(BaseModel):
     message: Optional[str] = None
    error_code: int = ErrorCode.OK
 
-class EmbeddingWorkerResult(BaseWorkerResult):
+class EmbeddingsObject(BaseModel):
     embedding: List[float]
+    index: int
+
+class EmbeddingWorkerResult(BaseWorkerResult):
+    embeddings: List[EmbeddingsObject]
     usage: UsageInfo = None
 
 class GenerationWorkerLogprobs(BaseModel):
diff --git a/langport/routers/gateway/openai_compatible.py b/langport/routers/gateway/openai_compatible.py
index c147cf0..d7e2490 100644
--- a/langport/routers/gateway/openai_compatible.py
+++ b/langport/routers/gateway/openai_compatible.py
@@ -445,7 +445,7 @@ async def api_embeddings(app_settings: AppSettings, request: EmbeddingsRequest):
    if response.type == "error":
        return create_error_response(ErrorCode.INTERNAL_ERROR, response.message)
    return EmbeddingsResponse(
-        data=[EmbeddingsData(embedding=response.embedding, index=0)],
+        data=[EmbeddingsData(embedding=each.embedding, index=each.index) for each in response.embeddings],
        model=request.model,
        usage=UsageInfo(
            prompt_tokens=response.usage.prompt_tokens,
diff --git a/langport/service/gateway/openai_api.py b/langport/service/gateway/openai_api.py
index 9b1bf22..36fce05 100644
--- a/langport/service/gateway/openai_api.py
+++ b/langport/service/gateway/openai_api.py
@@ -11,6 +11,7 @@
 from fastapi.responses import JSONResponse
 from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction
 from starlette.types import ASGIApp
+from starlette.requests import Request
 import uvicorn
 
 from langport.constants import LOGDIR, ErrorCode
@@ -48,7 +49,7 @@ async def dispatch(self, request, call_next):
         return await call_next(request)
 
 redirect_rules = None
-def redirect_model_name(model:str):
+def redirect_model_name(model: str):
     if redirect_rules is not None:
         for rule in redirect_rules:
             from_model_name, to_model_name = rule.split(":")
@@ -60,7 +61,7 @@
 
 
 @app.exception_handler(RequestValidationError)
-async def validation_exception_handler(request, exc):
+async def validation_exception_handler(request: Request, exc):
     return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
 
 
@@ -87,6 +88,7 @@ async def completions(request: CompletionRequest):
 
 @app.post("/v1/embeddings")
 async def embeddings(request: EmbeddingsRequest):
+    logger.info(request.json())
     request.model = redirect_model_name(request.model)
     response = await api_embeddings(app.app_settings, request)
     return response
diff --git a/pyproject.toml b/pyproject.toml
index cf91fb5..728f025 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,8 +13,8 @@ classifiers = [
     "License :: OSI Approved :: Apache Software License",
 ]
 dependencies = [
-    "fastapi", "httpx", "numpy", "pydantic<=1.10.13", "requests",
-    "rich>=10.0.0", "sentencepiece", "datasets>=2.14.5", "cachetools", "asyncache",
+    "fastapi", "httpx", "pydantic<=1.10.13", "requests",
+    "rich>=10.0.0", "datasets>=2.14.5", "cachetools", "asyncache",
     "shortuuid", "tokenizers>=0.14.1", "chatproto", "transformers>=4.34.0",
     "uvicorn", "wandb", "tenacity>=8.2.2",
 ]
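
For reference, a minimal client-side sketch of the list-valued input that this patch enables on the OpenAI-compatible /v1/embeddings route. The gateway URL, port, and model name below are assumptions for illustration only, not values defined by this diff; requests is already listed in the project dependencies.

import requests

# Assumed local gateway address and model name; substitute your own deployment values.
API_URL = "http://localhost:8000/v1/embeddings"

payload = {
    "model": "my-embedding-model",
    # "input" may now be either a single string or a list of strings.
    "input": ["first sentence to embed", "second sentence to embed"],
}

response = requests.post(API_URL, json=payload)
response.raise_for_status()
for item in response.json()["data"]:
    # Each entry carries its position in the input list and the embedding vector.
    print(item["index"], len(item["embedding"]))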