Skip to content

Commit

Permalink
Update embedding API
Browse files — browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jun 29, 2024
1 parent d8d54e7 commit 35c1fe5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 15 deletions.
10 changes: 5 additions & 5 deletions langport/model/executor/embedding/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,19 @@ def inference(self, worker: "EmbeddingModelWorker"):
data = model(**encoded_prompts)
# embeddings = torch.mean(data, dim=1)
embeddings = self._mean_pooling(data, encoded_prompts['attention_mask'])
for task_i in range(len(tasks)):
for task_i, cur_task in enumerate(tasks):
token_num = 0
embedding_list = []
for prompt_i in range(len(prompts)):
if prompts_index[prompt_i] == task_i:
token_num += len(tokenizer(prompts[i]).input_ids)
token_num += len(tokenizer(prompts[prompt_i]).input_ids)
embedding_list.append(EmbeddingsObject(index=task_i, embedding=embeddings[prompt_i].tolist()))
worker.push_task_result(
tasks[i].task_id,
cur_task.task_id,
EmbeddingWorkerResult(
task_id=tasks[i].task_id,
task_id=cur_task.task_id,
type="data",
embedding=embedding_list,
embeddings=embedding_list,
usage=UsageInfo(prompt_tokens=token_num, total_tokens=token_num),
)
)
Expand Down
4 changes: 3 additions & 1 deletion langport/protocol/openai_api_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,12 @@ class EmbeddingsRequest(BaseModel):
model: str
input: Union[str, List[str]]
user: Optional[str] = None
encoding_format: Optional[Literal["float", "base64"]] = None
dimensions: Optional[int] = None

class EmbeddingsData(BaseModel):
object: str = "embedding"
embedding: List[float]
embedding: Union[List[float], str]
index: int

class EmbeddingsResponse(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions langport/protocol/worker_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class EmbeddingsTask(BaseWorkerTask):
model: str
input: Union[str, List[str]]
user: Optional[str] = None
dimensions: Optional[int] = None

class GenerationTask(BaseWorkerTask):
prompt: str
Expand Down
40 changes: 31 additions & 9 deletions langport/routers/gateway/openai_compatible.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio

import asyncio
import base64
import json

from typing import Coroutine, Generator, Optional, Union, Dict, List, Any
Expand All @@ -9,6 +10,8 @@
import httpx
import shortuuid

import numpy as np

from langport.constants import WORKER_API_TIMEOUT, ErrorCode
from langport.model.model_adapter import get_conversation_template
from langport.protocol.openai_api_protocol import (
Expand Down Expand Up @@ -452,17 +455,36 @@ async def api_embeddings(app_settings: AppSettings, request: EmbeddingsRequest):
payload = {
"model": request.model,
"input": request.input,
"dimensions": request.dimensions,
}

response = await get_embedding(app_settings, payload)
if response.type == "error":
return create_bad_request_response(ErrorCode.INTERNAL_ERROR, response.message)
return EmbeddingsResponse(
data=[EmbeddingsData(embedding=each.embedding, index=each.index) for each in response.embeddings],
model=request.model,
usage=UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
completion_tokens=None,
),
).dict(exclude_none=True)

if request.encoding_format is None or request.encoding_format == "float":
return EmbeddingsResponse(
data=[EmbeddingsData(embedding=each.embedding, index=each.index) for each in response.embeddings],
model=request.model,
usage=UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
completion_tokens=None,
),
).dict(exclude_none=True)
elif request.encoding_format == "base64":
return EmbeddingsResponse(
data=[EmbeddingsData(
embedding=base64.b64encode(np.array(each.embedding, dtype="float32").tobytes()).decode("utf-8"),
index=each.index
) for each in response.embeddings
],
model=request.model,
usage=UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
completion_tokens=None,
),
).dict(exclude_none=True)
else:
raise Exception("Invalid encoding_format param.")

0 comments on commit 35c1fe5

Please sign in to comment.