4 changes: 3 additions & 1 deletion .github/workflows/test.yaml
@@ -49,4 +49,6 @@ jobs:

- name: Upload coverage Report to Codecov for python 3.10
if: ${{ matrix.python-version == '3.10' && inputs.upload_coverage == true }}
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
4 changes: 2 additions & 2 deletions README.md
@@ -4,9 +4,9 @@ Embedding Inference Server - finding TGI for embeddings
## Why Infinity:
Infinity provides the following features:
- **Fast inference**: The inference server is built on top of [torch](https:) and [ctranslate2](https://github.com/OpenNMT/CTranslate2) under the hood, getting most out of your **CUDA** or **CPU** hardware.
- **Continous batching**: All new embedding requests are queued while GPU is busy with the previous ones. New requests are served as soon as GPU is ready. Adds only ~2% overhead for large datasets, over static batching.
- **Dynamic batching**: New embedding requests are queued while GPU is busy with the previous ones. New requests are squeezed into your GPU/CPU as soon as it is ready.
- **Correct and tested implementation**: Unit and end-to-end tested. API embeddings are identical to [sentence-transformers](https://github.com/UKPLab/sentence-transformers/) (up to numerical precision). Lets API users create embeddings till infinity and beyond.
- **Easy to use**: The API is built on top of [FastAPI](https://fastapi.tiangolo.com/) and [Swagger](https://swagger.io/) and is fully documented. See below on how to get started.
- **Easy to use**: The API is built on top of [FastAPI](https://fastapi.tiangolo.com/) and fully documented via [Swagger](https://swagger.io/). The API specs are aligned with the OpenAI API. See below on how to get started.

# Demo:
A quick demo of launching: [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) with batch-size=2 and sending 3 requests via cURL.
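For orientation, here is a rough client-side sketch of such a request in Python; the port and the `/v1` route prefix are assumptions and not taken from this PR:

```python
import requests  # pip install requests

# Assumed host, port and route prefix; adjust to how the server was launched.
URL = "http://localhost:8080/v1/embeddings"

payload = {
    "model": "sentence-transformers/all-MiniLM-L6-v2",
    "input": ["A sentence to embed", "Another sentence"],
}

response = requests.post(URL, json=payload, timeout=30)
response.raise_for_status()
result = response.json()
# OpenAI-style result: a list of embedding objects plus token usage.
print(len(result["data"]), result["usage"])
```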
21 changes: 4 additions & 17 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/convert.py
@@ -1,32 +1,19 @@
from ..inference.primitives import NpEmbeddingType
from .pymodels import OpenAIEmbeddingResult, _EmbeddingObject, _Usage
from .pymodels import OpenAIEmbeddingResult


def list_embeddings_to_response(
embeddings: NpEmbeddingType, model: str, usage: int
) -> OpenAIEmbeddingResult:
return OpenAIEmbeddingResult(
return dict(
model=model,
data=[
_EmbeddingObject(
dict(
object="embedding",
embedding=emb,
index=count,
)
for count, emb in enumerate(embeddings)
],
usage=_Usage(prompt_tokens=usage, total_tokens=usage),
usage=dict(prompt_tokens=usage, total_tokens=usage),
)

# return {
# "model": model,
# "data": [
# dict(
# object="embedding",
# embedding=emb,
# index=count,
# )
# for count, emb in enumerate(embeddings)
# ],
# "usage": {"prompt_tokens": usage, "total_tokens": usage},
# }
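For reference, a self-contained sketch of the response shape the converter now builds with plain dicts; `_sketch_response` is a hypothetical stand-in, not part of the codebase, and FastAPI's `response_model=OpenAIEmbeddingResult` still validates the real endpoint output:

```python
from typing import Any, Dict

import numpy as np


def _sketch_response(embeddings: np.ndarray, model: str, usage: int) -> Dict[str, Any]:
    # Mirrors the dict layout returned by list_embeddings_to_response after this change.
    return dict(
        model=model,
        data=[
            dict(object="embedding", embedding=emb.tolist(), index=count)
            for count, emb in enumerate(embeddings)
        ],
        usage=dict(prompt_tokens=usage, total_tokens=usage),
    )


print(_sketch_response(np.random.rand(2, 4), "all-MiniLM-L6-v2", usage=7)["usage"])
```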
16 changes: 16 additions & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/docs.py
@@ -0,0 +1,16 @@
FASTAPI_TITLE = "♾️ Infinity - Embedding Inference Server"
FASTAPI_SUMMARY = "Embedding Inference Server - finding TGI for embeddings"


def startup_message(host: str, port: str, prefix: str) -> str:
return f"""

♾️ Infinity - Embedding Inference Server
MIT License; Copyright (c) 2023 Michael Feil

Open the Docs via Swagger UI:
http://{host}:{port}/docs

Access model via 'GET':
curl http://{host}:{port}{prefix}/models
"""
20 changes: 12 additions & 8 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -8,7 +8,7 @@
from typing import Dict, List, Union

from ..log_handler import logger
from .models import BaseTransformer
from .models import BaseTransformer, get_lengths_with_tokenize
from .primitives import (
EmbeddingResult,
NpEmbeddingType,
@@ -117,7 +117,7 @@ def __init__(
self._queue_prio = CustomPrioQueue()
self._result_store = ResultKVStore()
self._feature_queue: queue.Queue = queue.Queue(4)
self._postprocess_queue: queue.Queue = queue.Queue(5)
self._postprocess_queue: queue.Queue = queue.Queue(4)
self.max_batch_size = max_batch_size
self.model = model
self.max_queue_wait = max_queue_wait
@@ -128,9 +128,7 @@ def __init__(
def shutdown(self):
self._shutdown.set()

async def schedule(
self, sentences: List[str], prios: List[int]
) -> NpEmbeddingType | None:
async def schedule(self, sentences: List[str]) -> tuple[List[NpEmbeddingType], int]:
"""Schedule a sentence to be embedded. Awaits until embedded.

Args:
@@ -143,6 +141,11 @@ async def schedule(
# add an unique identifier
uuid_event = []
prioqueue = []

prios, usage = get_lengths_with_tokenize(
sentences
) # , self.model.tokenize_lengths)

for s, p in zip(sentences, prios):
inner = EmbeddingResult(sentence=s, event=EventTS(self._threadpool))
item = PrioritizedQueueItem(item=inner, priority=p)
@@ -154,7 +157,8 @@ def __init__(
self._result_store.wait_for_response(uuid, event)
for uuid, event in uuid_event
]
return await asyncio.gather(*gather_results)
embeddings = await asyncio.gather(*gather_results)
return embeddings, usage

def is_overloaded(self) -> bool:
# start consuming
@@ -176,7 +180,7 @@ def overload_status(self) -> OverloadStatus:
def _preprocess_batch(self):
"""loops and checks if the _core_batch has worked on all items"""
self._ready = True
logger.info("ready to receive requests.")
logger.info("ready to batch requests.")
try:
while not self._shutdown.is_set():
# patience:
@@ -264,7 +268,7 @@ async def _postprocess_batch(self):
except queue.Empty:
# 7 ms, assuming this is below
# 3-50ms for inference on avg.
await asyncio.sleep(7e-3)
await asyncio.sleep(5e-3)
continue
embed, batch = post_batch
embeddings = self.model.encode_post(embed).tolist()
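A caller-side sketch of the new `schedule()` contract: priorities and token usage are now computed inside the handler, so callers pass only the sentences and get `(embeddings, usage)` back. The `Protocol` below is illustrative, not part of the codebase:

```python
from typing import List, Protocol, Tuple


class _SchedulerLike(Protocol):
    # Minimal caller-side view of BatchHandler.schedule after this change.
    async def schedule(self, sentences: List[str]) -> Tuple[List, int]:
        ...


async def embed(handler: _SchedulerLike, sentences: List[str]) -> None:
    # No per-request priorities are passed anymore; the handler tokenizes,
    # derives priorities from sentence lengths and tracks token usage itself.
    embeddings, usage = await handler.schedule(sentences)
    print(f"{len(embeddings)} embeddings, {usage} prompt tokens")
```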
42 changes: 16 additions & 26 deletions libs/infinity_emb/infinity_emb/inference/models.py
@@ -1,3 +1,4 @@
import copy
import os
from abc import ABC, abstractmethod
from enum import Enum
@@ -70,9 +71,12 @@ def __init__(self, *args, **kwargs):
device = self._target_device
self.eval()
self.to(device)
# make a copy of the tokenizer,
# to be able to count the tokens in another thread
# without corrupting the original.
self._infinity_tokenizer = copy.deepcopy(self._first_module().tokenizer)

def encode_pre(self, sentences) -> Dict[str, Tensor]:
# features = self._tokenize_actual(sentences)
features = self.tokenize(sentences)

return features
@@ -81,51 +85,37 @@ def encode_core(self, features: Dict[str, Tensor]) -> Tensor:
"""
Computes sentence embeddings
"""
# features = self._tokenize_actual(features)
device = self._target_device
features = util.batch_to_device(features, device)
# move forward

with torch.no_grad():
out_features = self.forward(features)
with torch.inference_mode():
device = self._target_device
features = util.batch_to_device(features, device)
out_features = self.forward(features)["sentence_embedding"]

return out_features["sentence_embedding"].detach().cpu()
return out_features

def encode_post(
self, out_features: Tensor, normalize_embeddings: bool = True
) -> NpEmbeddingType:
with torch.no_grad():
embeddings = out_features
with torch.inference_mode():
embeddings = out_features.detach().cpu()
if normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
embeddings_out: np.ndarray = embeddings.cpu().numpy()
embeddings_out: np.ndarray = embeddings.numpy()

return embeddings_out

def tokenize_lengths(self, sentences: List[str]) -> List[int]:
fm = self._first_module()
tks = fm.tokenizer.batch_encode_plus(
tks = self._infinity_tokenizer.batch_encode_plus(
sentences,
add_special_tokens=False,
return_token_type_ids=False,
return_attention_mask=False,
return_length=False,
# max_length=self._infinity_tokenizer.model_max_length,
# truncation="longest_first",
).encodings
return [len(t.tokens) for t in tks]

def _tokenize_actual(self, sentences: List[str]):
fm = self._first_module()
output = fm.tokenizer(
sentences,
padding=True,
truncation="longest_first",
return_tensors="pt",
max_length=fm.tokenizer.model_max_length,
# pad_to_multiple_of=16,
)

return dict(**output)


class CT2SentenceTransformer(SentenceTransformerPatched):
"""
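A usage sketch of the pre/core/post split, assuming `SentenceTransformerPatched` can be constructed directly from a Hugging Face model name like its `SentenceTransformer` base class:

```python
from infinity_emb.inference.models import SentenceTransformerPatched

model = SentenceTransformerPatched("sentence-transformers/all-MiniLM-L6-v2")
sentences = ["embed this", "and this as well"]

token_counts = model.tokenize_lengths(sentences)  # counts tokens via the deep-copied tokenizer
features = model.encode_pre(sentences)            # tokenize with the original tokenizer
out = model.encode_core(features)                 # forward pass under torch.inference_mode()
embeddings = model.encode_post(out)               # detach, move to CPU, L2-normalize, convert to numpy

print(embeddings.shape, token_counts)
```

The deep-copied tokenizer exists so that `tokenize_lengths` can run in another thread without corrupting the state of the tokenizer used for the actual encoding.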
57 changes: 33 additions & 24 deletions libs/infinity_emb/infinity_emb/infinity_server.py
@@ -3,15 +3,14 @@

import typer
import uvicorn
from fastapi import FastAPI, status
from fastapi import FastAPI, responses, status
from prometheus_fastapi_instrumentator import Instrumentator

# prometheus
import infinity_emb
from infinity_emb.fastapi_schemas import errors
from infinity_emb.fastapi_schemas import docs, errors
from infinity_emb.fastapi_schemas.convert import list_embeddings_to_response
from infinity_emb.fastapi_schemas.pymodels import (
ModelInfo,
OpenAIEmbeddingInput,
OpenAIEmbeddingResult,
OpenAIModelInfo,
@@ -26,10 +25,11 @@ def create_server(
batch_size: int = 64,
engine: models.InferenceEngine = models.InferenceEngine.torch,
verbose: bool = False,
doc_extra: dict = {},
):
app = FastAPI(
title="♾️ Infinity - Embedding Inference Server",
summary="Embedding Inference Server - finding TGI for embeddings",
title=docs.FASTAPI_TITLE,
summary=docs.FASTAPI_SUMMARY,
version=infinity_emb.__version__,
contact=dict(name="Michael Feil"),
docs_url="/docs",
@@ -53,10 +53,17 @@ async def _startup():
app.batch_handler = BatchHandler(
max_batch_size=batch_size, model=model, threadpool=app.tp, verbose=verbose
)
app.tokenize_len = model.tokenize_lengths
# start in a threadpool
await app.batch_handler.spawn()

logger.info(
docs.startup_message(
host=doc_extra.pop("host", "localhost"),
port=doc_extra.pop("port", "PORT"),
prefix=url_prefix,
)
)

@app.on_event("shutdown")
async def _shutdown():
app.batch_handler.shutdown()
@@ -71,23 +78,32 @@ async def _ready() -> float:
"model not ready", code=status.HTTP_503_SERVICE_UNAVAILABLE
)

@app.get(f"{url_prefix}/models")
async def _models() -> OpenAIModelInfo:
@app.get(
f"{url_prefix}/models",
response_model=OpenAIModelInfo,
response_class=responses.ORJSONResponse,
)
async def _models():
"""get models endpoint"""
s = app.batch_handler.overload_status() # type: ignore
return OpenAIModelInfo(
data=ModelInfo(
return dict(
data=dict(
id=model_name_or_path,
stats=dict(
queue_fraction=s.queue_fraction,
queue_absolute=s.queue_absolute,
results_pending=s.results_absolute,
batch_size=batch_size,
),
)
)

@app.post(f"{url_prefix}/embeddings")
async def _embeddings(data: OpenAIEmbeddingInput) -> OpenAIEmbeddingResult:
@app.post(
f"{url_prefix}/embeddings",
response_model=OpenAIEmbeddingResult,
response_class=responses.ORJSONResponse,
)
async def _embeddings(data: OpenAIEmbeddingInput):
"""Encode Embeddings

```python
@@ -102,25 +118,16 @@ async def _embeddings(data: OpenAIEmbeddingInput) -> OpenAIEmbeddingResult:
)

try:
logger.debug("[📝] Received request with %s inputs ", len(data.input))
start = time.perf_counter()

# lengths, usage = await to_thread(
# models.get_lengths_with_tokenize, app.tp, data.input, app.tokenize_len)
lengths, usage = models.get_lengths_with_tokenize(
data.input # , app.tokenize_len
)
logger.debug("[📝] Received request with %s inputs ", len(lengths))

# emb = await asyncio.gather(
# *[(bh.schedule(s, prio=prio)) for s, prio in zip(data.input, lengths)]
# )
emb = await bh.schedule(data.input, prios=lengths)
embedding, usage = await bh.schedule(data.input)

duration = (time.perf_counter() - start) * 1000
logger.debug("[✅] Done in %s ms", duration)

res = list_embeddings_to_response(
embeddings=emb, model=data.model, usage=usage
embeddings=embedding, model=data.model, usage=usage
)

return res
@@ -165,6 +172,7 @@ def start_uvicorn(
batch_size=batch_size,
engine=engine_load,
verbose=log_level.to_int() <= 10,
doc_extra=dict(host=host, port=port),
)
uvicorn.run(app, host=host, port=port, log_level=log_level.name)

@@ -174,6 +182,7 @@ def cli():
typer.run(start_uvicorn)


# app = create_server()
if __name__ == "__main__":
# for debugging
cli()
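Finally, a sketch of wiring the server up programmatically; `model_name_or_path` is assumed to be a `create_server` parameter (it is referenced in the visible hunks, but its definition sits outside this diff), and the import paths are inferred from the repository layout:

```python
import uvicorn

from infinity_emb import infinity_server
from infinity_emb.inference import models

app = infinity_server.create_server(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",  # assumed parameter
    batch_size=64,
    engine=models.InferenceEngine.torch,
    verbose=False,
    doc_extra=dict(host="0.0.0.0", port=8080),  # only used for the startup banner
)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)
```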