Commit d81c8ba

update embed api
jstzwj committed Feb 4, 2024
1 parent 197a672 commit d81c8ba
Showing 6 changed files with 38 additions and 15 deletions.
langport/model/executor/embedding/huggingface.py (29 changes: 23 additions & 6 deletions)
@@ -4,7 +4,7 @@
 
 import torch
 from langport.model.executor.huggingface import HuggingfaceExecutor
-from langport.protocol.worker_protocol import BaseWorkerResult, EmbeddingWorkerResult, UsageInfo
+from langport.protocol.worker_protocol import BaseWorkerResult, EmbeddingWorkerResult, EmbeddingsObject, UsageInfo
 from langport.workers.embedding_worker import EmbeddingModelWorker
 
 
@@ -118,7 +118,19 @@ def inference(self, worker: "EmbeddingModelWorker"):
             self.wakeup()
 
         # print(batch_size)
-        prompts = [task.input for task in tasks]
+        prompts = []
+        prompts_index = []
+        for task_i, task in enumerate(tasks):
+            task_input = task.input
+            if isinstance(task_input, str):
+                prompts.append(task_input)
+                prompts_index.append(task_i)
+            elif isinstance(task_input, list):
+                prompts.extend(task_input)
+                prompts_index.extend([task_i] * len(task_input))
+            else:
+                raise Exception("Invalid prompt type...")
+
         try:
             tokenizer = self.tokenizer
             model = self.model
@@ -133,7 +145,7 @@ def inference(self, worker: "EmbeddingModelWorker"):
            input_ids = encoded_prompts.input_ids
            if model.config.is_encoder_decoder:
                decoder_input_ids = torch.full(
-                    (batch_size, 1),
+                    (len(prompts), 1),
                    model.generation_config.decoder_start_token_id,
                    dtype=torch.long,
                    device=self.device,
@@ -151,14 +163,19 @@ def inference(self, worker: "EmbeddingModelWorker"):
                data = model(**encoded_prompts)
            # embeddings = torch.mean(data, dim=1)
            embeddings = self._mean_pooling(data, encoded_prompts['attention_mask'])
-            for i in range(batch_size):
-                token_num = len(tokenizer(prompts[i]).input_ids)
+            for task_i in range(len(tasks)):
+                token_num = 0
+                embedding_list = []
+                for prompt_i in range(len(prompts)):
+                    if prompts_index[prompt_i] == task_i:
+                        token_num += len(tokenizer(prompts[prompt_i]).input_ids)
+                        embedding_list.append(EmbeddingsObject(index=task_i, embedding=embeddings[prompt_i].tolist()))
                worker.push_task_result(
-                    tasks[i].task_id,
+                    tasks[task_i].task_id,
                    EmbeddingWorkerResult(
-                        task_id=tasks[i].task_id,
+                        task_id=tasks[task_i].task_id,
                        type="data",
-                        embedding=embeddings[i].tolist(),
+                        embeddings=embedding_list,
                        usage=UsageInfo(prompt_tokens=token_num, total_tokens=token_num),
                    )
                )
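
The flatten-and-regroup scheme above can be hard to follow in diff form. Here is a minimal, self-contained sketch of the same idea; the task inputs and the fake embedding step are illustrative stand-ins, not the worker's real tokenizer and model:

    # Each task input may be a str or a list of str. prompts_index records,
    # for every flattened prompt, the index of the task it came from.
    task_inputs = ["hello", ["a", "b"], "world"]  # illustrative inputs

    prompts, prompts_index = [], []
    for task_i, task_input in enumerate(task_inputs):
        if isinstance(task_input, str):
            prompts.append(task_input)
            prompts_index.append(task_i)
        elif isinstance(task_input, list):
            prompts.extend(task_input)
            prompts_index.extend([task_i] * len(task_input))

    embeddings = [[float(len(p))] for p in prompts]  # fake per-prompt vectors

    # Regroup: each task collects the embeddings of exactly its own prompts.
    for task_i in range(len(task_inputs)):
        group = [embeddings[p] for p in range(len(prompts)) if prompts_index[p] == task_i]
        print(task_i, group)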
langport/protocol/openai_api_protocol.py (4 changes: 2 additions & 2 deletions)
@@ -115,10 +115,10 @@ class ChatCompletionStreamResponse(BaseModel):
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
 
-
+# TODO: Support List[int] and List[List[int]]
 class EmbeddingsRequest(BaseModel):
     model: str
-    input: str
+    input: Union[str, List[str]]
     user: Optional[str] = None
 
 class EmbeddingsData(BaseModel):
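
With input widened to Union[str, List[str]], the request model now accepts either a single string or a batch. A small sketch of the two accepted shapes (the model name is illustrative):

    from typing import List, Optional, Union
    from pydantic import BaseModel

    class EmbeddingsRequest(BaseModel):
        model: str
        input: Union[str, List[str]]
        user: Optional[str] = None

    single = EmbeddingsRequest(model="all-MiniLM-L6-v2", input="hello world")
    batch = EmbeddingsRequest(model="all-MiniLM-L6-v2", input=["hello", "world"])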
langport/protocol/worker_protocol.py (8 changes: 6 additions & 2 deletions)
@@ -76,7 +76,7 @@ class BaseWorkerTask(BaseModel):
 
 class EmbeddingsTask(BaseWorkerTask):
     model: str
-    input: str
+    input: Union[str, List[str]]
     user: Optional[str] = None
 
 class GenerationTask(BaseWorkerTask):
@@ -103,8 +103,12 @@ class BaseWorkerResult(BaseModel):
     message: Optional[str] = None
     error_code: int = ErrorCode.OK
 
-class EmbeddingWorkerResult(BaseWorkerResult):
+class EmbeddingsObject(BaseModel):
     embedding: List[float]
+    index: int
+
+class EmbeddingWorkerResult(BaseWorkerResult):
+    embeddings: List[EmbeddingsObject]
     usage: UsageInfo = None
 
 class GenerationWorkerLogprobs(BaseModel):
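
A worker result now carries a list of per-prompt vectors instead of one embedding field. A hedged sketch of constructing the new result type (the task id and numbers are illustrative):

    from langport.protocol.worker_protocol import (
        EmbeddingWorkerResult, EmbeddingsObject, UsageInfo,
    )

    result = EmbeddingWorkerResult(
        task_id="task-0",  # illustrative task id
        type="data",
        embeddings=[
            EmbeddingsObject(index=0, embedding=[0.1, 0.2]),
            EmbeddingsObject(index=0, embedding=[0.3, 0.4]),
        ],
        usage=UsageInfo(prompt_tokens=4, total_tokens=4),
    )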
langport/routers/gateway/openai_compatible.py (2 changes: 1 addition & 1 deletion)
@@ -445,7 +445,7 @@ async def api_embeddings(app_settings: AppSettings, request: EmbeddingsRequest):
     if response.type == "error":
         return create_error_response(ErrorCode.INTERNAL_ERROR, response.message)
     return EmbeddingsResponse(
-        data=[EmbeddingsData(embedding=response.embedding, index=0)],
+        data=[EmbeddingsData(embedding=each.embedding, index=each.index) for each in response.embeddings],
         model=request.model,
         usage=UsageInfo(
             prompt_tokens=response.usage.prompt_tokens,
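
The gateway no longer assumes a single embedding per response; it emits one EmbeddingsData entry per returned vector, preserving each vector's index. Inside api_embeddings the new comprehension is equivalent to this expanded form:

    data = []
    for each in response.embeddings:
        data.append(EmbeddingsData(embedding=each.embedding, index=each.index))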
langport/service/gateway/openai_api.py (6 changes: 4 additions & 2 deletions)
@@ -11,6 +11,7 @@
 from fastapi.responses import JSONResponse
 from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction
 from starlette.types import ASGIApp
+from starlette.requests import Request
 import uvicorn
 
 from langport.constants import LOGDIR, ErrorCode
@@ -48,7 +49,7 @@ async def dispatch(self, request, call_next):
         return await call_next(request)
 
 redirect_rules = None
-def redirect_model_name(model:str):
+def redirect_model_name(model: str):
     if redirect_rules is not None:
         for rule in redirect_rules:
             from_model_name, to_model_name = rule.split(":")
@@ -60,7 +61,7 @@ def redirect_model_name(model: str):
 
 
 @app.exception_handler(RequestValidationError)
-async def validation_exception_handler(request, exc):
+async def validation_exception_handler(request: Request, exc):
     return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
 
 
@@ -87,6 +88,7 @@ async def completions(request: CompletionRequest):
 
 @app.post("/v1/embeddings")
 async def embeddings(request: EmbeddingsRequest):
+    logger.info(request.json())
     request.model = redirect_model_name(request.model)
     response = await api_embeddings(app.app_settings, request)
     return response
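
End to end, the gateway should now accept OpenAI-style batched embedding requests. A hedged example against a locally running gateway (the address and model name are assumptions, not project defaults):

    import httpx  # already a project dependency

    resp = httpx.post(
        "http://localhost:8000/v1/embeddings",  # assumed gateway address
        json={"model": "all-MiniLM-L6-v2", "input": ["hello", "world"]},
    )
    for item in resp.json()["data"]:
        print(item["index"], len(item["embedding"]))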
pyproject.toml (4 changes: 2 additions & 2 deletions)
@@ -13,8 +13,8 @@ classifiers = [
     "License :: OSI Approved :: Apache Software License",
 ]
 dependencies = [
-    "fastapi", "httpx", "numpy", "pydantic<=1.10.13", "requests",
-    "rich>=10.0.0", "sentencepiece", "datasets>=2.14.5", "cachetools", "asyncache",
+    "fastapi", "httpx", "pydantic<=1.10.13", "requests",
+    "rich>=10.0.0", "datasets>=2.14.5", "cachetools", "asyncache",
     "shortuuid", "tokenizers>=0.14.1", "chatproto",
     "transformers>=4.34.0", "uvicorn", "wandb", "tenacity>=8.2.2",
 ]
