
Commit

embedding worker batch limit
jstzwj committed Jul 1, 2024
1 parent 0028262 commit 84b3e07
Showing 9 changed files with 95 additions and 53 deletions.
4 changes: 2 additions & 2 deletions langport/core/cluster_worker.py
@@ -104,8 +104,8 @@ async def fetch_task_result(self, task_id: str):
await asyncio.sleep(0.01)
retry_counter += 1
# If client disconnected, stop to wait queue.
if retry_counter > 2000:
break
if retry_counter > 60 * 100:
raise ValueError("Worker task execution timeout")
else:
continue
retry_counter = 0
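
The new limit effectively turns the retry loop into a ~60-second timeout: each retry sleeps 0.01 s, so 60 * 100 iterations elapse in about a minute before the ValueError is raised, instead of silently breaking out after 2000 retries. A minimal sketch of the same polling pattern (the queue_empty callable and the default constants are illustrative, not the worker's actual API):

import asyncio

async def wait_for_result(queue_empty, sleep_s: float = 0.01, timeout_s: float = 60.0):
    # Poll until the result queue has data, or give up after ~timeout_s.
    max_retries = int(timeout_s / sleep_s)  # 60 * 100 with the values in this commit
    retries = 0
    while queue_empty():
        await asyncio.sleep(sleep_s)
        retries += 1
        if retries > max_retries:
            raise ValueError("Worker task execution timeout")
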
72 changes: 40 additions & 32 deletions langport/model/executor/embedding/huggingface.py
@@ -8,6 +8,7 @@
from langport.model.executor.huggingface import HuggingfaceExecutor
from langport.protocol.worker_protocol import BaseWorkerResult, EmbeddingWorkerResult, EmbeddingsObject, UsageInfo
from langport.workers.embedding_worker import EmbeddingModelWorker
from langport.utils.itertools import batched

class HuggingfaceEmbeddingExecutor(HuggingfaceExecutor):
def __init__(
@@ -120,6 +121,37 @@ def _mean_pooling(self, model_output, attention_mask):
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def inference_batch(self, prompts: List[str]) -> List[str]:
tokenizer = self.tokenizer
model = self.model
if model.__class__.__module__ + '.' + model.__class__.__name__ != 'sentence_transformers.SentenceTransformer.SentenceTransformer':
encoded_prompts = tokenizer(prompts, return_tensors="pt", padding="longest").to(self.device)
input_ids = encoded_prompts.input_ids
if model.config.is_encoder_decoder:
decoder_input_ids = torch.full(
(len(prompts), 1),
model.generation_config.decoder_start_token_id,
dtype=torch.long,
device=self.device,
)
model_output = model(input_ids, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
data = model_output.decoder_hidden_states[-1]
elif model.config.is_decoder:
model_output = model(input_ids, output_hidden_states=True)
is_chatglm = "chatglm" in str(type(model)).lower()
if is_chatglm:
data = model_output.hidden_states[-1].transpose(0, 1)
else:
data = model_output.hidden_states[-1]
else:
data = model(**encoded_prompts)
# embeddings = torch.mean(data, dim=1)
embeddings = self._mean_pooling(data, encoded_prompts['attention_mask']).cpu()
else:
embeddings = model.encode(prompts, show_progress_bar=False)
return embeddings


@torch.inference_mode()
def inference(self, worker: "EmbeddingModelWorker"):
call_interval = time.time() - self.last_call_time
@@ -152,47 +184,23 @@ def inference(self, worker: "EmbeddingModelWorker"):
raise Exception("Invalid prompt type...")

try:
tokenizer = self.tokenizer
model = self.model

batch_prompts = batched(prompts, worker.max_batch)
embeddings = []
for each_batch in batch_prompts:
batch_embeddings = self.inference_batch(each_batch)
embeddings.extend(batch_embeddings)
# ValueError: Asking to pad but the tokenizer does not have a padding token.
# Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)`
# or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.
if tokenizer._pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

if model.__class__.__module__ + '.' + model.__class__.__name__ != 'sentence_transformers.SentenceTransformer.SentenceTransformer':
encoded_prompts = tokenizer(prompts, return_tensors="pt", padding="longest").to(self.device)
input_ids = encoded_prompts.input_ids
if model.config.is_encoder_decoder:
decoder_input_ids = torch.full(
(len(prompts), 1),
model.generation_config.decoder_start_token_id,
dtype=torch.long,
device=self.device,
)
model_output = model(input_ids, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
data = model_output.decoder_hidden_states[-1]
elif model.config.is_decoder:
model_output = model(input_ids, output_hidden_states=True)
is_chatglm = "chatglm" in str(type(model)).lower()
if is_chatglm:
data = model_output.hidden_states[-1].transpose(0, 1)
else:
data = model_output.hidden_states[-1]
else:
data = model(**encoded_prompts)
# embeddings = torch.mean(data, dim=1)
embeddings = self._mean_pooling(data, encoded_prompts['attention_mask']).cpu()
else:
embeddings = model.encode(prompts)
if self.tokenizer._pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

for task_i, cur_task in enumerate(tasks):
token_num = 0
embedding_list = []
for prompt_i in range(len(prompts)):
if prompts_index[prompt_i] == task_i:
token_num += len(tokenizer(prompts[prompt_i]).input_ids)
token_num += len(self.tokenizer(prompts[prompt_i]).input_ids)
embedding_list.append(EmbeddingsObject(index=task_i, embedding=embeddings[prompt_i].tolist()))
worker.push_task_result(
cur_task.task_id,
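
Taken together, these changes mean inference() no longer embeds every pending prompt in a single forward pass: prompts are chunked by worker.max_batch and the per-chunk embeddings are concatenated, which bounds peak memory per call. A rough sketch of the resulting control flow (executor, prompts, and max_batch are placeholders here, not the worker's exact wiring):

from langport.utils.itertools import batched

def embed_in_chunks(executor, prompts, max_batch: int):
    # Each forward pass sees at most max_batch prompts; output order is preserved.
    embeddings = []
    for chunk in batched(prompts, max_batch):
        embeddings.extend(executor.inference_batch(chunk))
    return embeddings
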
2 changes: 1 addition & 1 deletion langport/model/executor/huggingface.py
@@ -164,7 +164,7 @@ def load_sentence_transformer_model(
if device == "cpu":
kwargs["torch_dtype"] = torch.float32
elif device == "cuda":
kwargs["torch_dtype"] = torch.float16
kwargs["torch_dtype"] = "auto"
if num_gpus != 1:
kwargs["device_map"] = "auto"
if max_gpu_memory is None:
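
Switching the CUDA default from torch.float16 to "auto" lets the checkpoint's own config decide the dtype (for example, a bfloat16 model stays bfloat16). Assuming the kwargs are ultimately forwarded to a Hugging Face from_pretrained call, the effect is roughly:

from transformers import AutoModel

# torch_dtype="auto" reads the dtype recorded in the checkpoint config
# instead of forcing every CUDA load to float16.
model = AutoModel.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2",  # example checkpoint, not from this commit
    torch_dtype="auto",
)
print(model.dtype)
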
16 changes: 12 additions & 4 deletions langport/routers/gateway/openai_compatible.py
@@ -4,7 +4,7 @@
import base64
import json

from typing import Coroutine, Generator, Optional, Union, Dict, List, Any
from typing import AsyncGenerator, Coroutine, Generator, Optional, Union, Dict, List, Any

from fastapi.responses import StreamingResponse
import httpx
@@ -42,7 +42,15 @@
GenerationWorkerResult,
)
from langport.core.dispatch import DispatchMethod
from langport.routers.gateway.common import LANGPORT_HEADER, AppSettings, _get_worker_address, _list_models, check_model, check_requests, create_bad_request_response
from langport.routers.gateway.common import (
LANGPORT_HEADER,
AppSettings,
_get_worker_address,
_list_models,
check_model,
check_requests,
create_bad_request_response
)

def clean_system_prompts(messages: List[Dict[str, str]]):
system_prompt = ""
@@ -190,7 +198,7 @@ async def generate_completion_stream_generator(app_settings: AppSettings, payloa
yield "data: [DONE]\n\n"


async def generate_completion_stream(app_settings: AppSettings, url: str, payload: Dict[str, Any]) -> Generator[GenerationWorkerResult, Any, None]:
async def generate_completion_stream(app_settings: AppSettings, url: str, payload: Dict[str, Any]) -> AsyncGenerator[GenerationWorkerResult, None]:
async with httpx.AsyncClient() as client:
try:
worker_addr = await _get_worker_address(app_settings, payload["model"], "generation", client, DispatchMethod.LOTTERY)
@@ -244,7 +252,7 @@ async def single_completions_non_stream(app_settings: AppSettings, payload: Dict

async def chat_completion_stream_generator(
app_settings: AppSettings, payload: Dict[str, Any], n: int
) -> Generator[str, Any, None]:
) -> AsyncGenerator[str, None]:
"""
Event stream format:
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
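
The annotation swap matters because these functions are async generators: typing.Generator describes synchronous generators, while AsyncGenerator is parameterized by yield type and send type only. A minimal example of the corrected signature style:

from typing import AsyncGenerator

async def sse_chunks(n: int) -> AsyncGenerator[str, None]:
    # AsyncGenerator[YieldType, SendType] -- two type parameters.
    for i in range(n):
        yield f"data: chunk {i}\n\n"

# Consumed with:  async for line in sse_chunks(3): ...
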
5 changes: 5 additions & 0 deletions langport/utils/itertools.py
@@ -0,0 +1,5 @@


def batched(data, batch_size: int):
for i in range(0, len(data), batch_size):
yield data[i:i+batch_size]
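
A quick usage check for the new helper; the final batch may be shorter than batch_size:

from langport.utils.itertools import batched

prompts = ["a", "b", "c", "d", "e"]
print(list(batched(prompts, 2)))
# [['a', 'b'], ['c', 'd'], ['e']]
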
2 changes: 1 addition & 1 deletion langport/version.py
@@ -1 +1 @@
LANGPORT_VERSION = "0.3.10"
LANGPORT_VERSION = "0.3.11"
21 changes: 16 additions & 5 deletions langport/workers/embedding_worker.py
@@ -74,16 +74,27 @@ async def get_embeddings(self, task: EmbeddingsTask):
context_length = self.executor.context_length

if input_tokens > context_length:
ooc_message = f"This model's maximum context length is {context_length} tokens. "
f"However, you requested {input_tokens} tokens. "
f"Please reduce the length of the messages or completion."
self.logger.info(ooc_message)
return BaseWorkerResult(task_id=task.task_id,
type="error",
message=f"This model's maximum context length is {context_length} tokens. "
f"However, you requested {input_tokens} tokens. "
f"Please reduce the length of the messages or completion.",
message=ooc_message,
error_code=ErrorCode.CONTEXT_OVERFLOW
)

await self.add_task(task)
result = None
async for chunk in self.fetch_task_result(task.task_id):
result = chunk
try:
async for chunk in self.fetch_task_result(task.task_id):
result = chunk
except Exception as e:
self.logger.error(str(e))
return BaseWorkerResult(task_id=task.task_id,
type="error",
message=str(e),
error_code=ErrorCode.INTERNAL_ERROR
)

return result
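
With the try/except in place, a worker-side failure such as the new execution timeout reaches callers as an ordinary error result instead of an unhandled exception. A hedged sketch of how a caller might branch on it (worker and task are placeholders, and the attribute access assumes the result types expose the type/message/error_code fields shown above):

async def embed_or_report(worker, task):
    result = await worker.get_embeddings(task)
    if result is not None and result.type == "error":
        # e.g. CONTEXT_OVERFLOW, or INTERNAL_ERROR from the new timeout path
        print(f"embedding failed: {result.message} (code {result.error_code})")
        return None
    return result
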
24 changes: 17 additions & 7 deletions langport/workers/generation_worker.py
@@ -66,20 +66,30 @@ async def generation_stream(self, task: GenerationTask):
context_length = self.executor.context_length

if context_length is not None and prompt_tokens + max_tokens > context_length:
ooc_message = f"This model's maximum context length is {context_length} tokens. "
f"However, you requested {max_tokens + prompt_tokens} tokens "
f"({prompt_tokens} in the messages, "
f"{max_tokens} in the completion). "
f"Please reduce the length of the messages or completion."
self.logger.info(ooc_message)
yield BaseWorkerResult(task_id=task.task_id,
type="error",
message=f"This model's maximum context length is {context_length} tokens. "
f"However, you requested {max_tokens + prompt_tokens} tokens "
f"({prompt_tokens} in the messages, "
f"{max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
message=ooc_message,
error_code=ErrorCode.CONTEXT_OVERFLOW
)
return

await self.add_task(task)
async for chunk in self.fetch_task_result(task.task_id):
yield chunk
try:
async for chunk in self.fetch_task_result(task.task_id):
yield chunk
except Exception as e:
self.logger.error(str(e))
yield BaseWorkerResult(task_id=task.task_id,
type="error",
message=str(e),
error_code=ErrorCode.INTERNAL_ERROR
)

async def generation_bytes_stream(self, task: GenerationTask):
async for chunk in self.generation_stream(task):
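
Because generation_stream is itself an async generator, the new error result travels down the same channel as normal output and the consumer simply sees one more chunk. A short consumption sketch under the same placeholder assumptions as above:

async def run_generation(worker, task):
    async for chunk in worker.generation_stream(task):
        if chunk.type == "error":
            print(f"generation failed: {chunk.message}")
            return
        # handle normal generation chunks here
        print(chunk)
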
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "langport"
version = "0.3.10"
version = "0.3.11"
description = "A large language model serving platform."
readme = "README.md"
requires-python = ">=3.8"
