feat: Configurable model for embeddings
Signed-off-by: Diwank Singh Tomer <diwank.singh@gmail.com>
creatorrr committed Sep 20, 2024
1 parent 5e8f4c8 commit 44e66d0
Showing 16 changed files with 643 additions and 598 deletions.
7 changes: 3 additions & 4 deletions agents-api/agents_api/activities/embed_docs.py
```diff
@@ -1,8 +1,7 @@
 from beartype import beartype
 from temporalio import activity

-from ..clients import cozo
-from ..clients import vertexai as embedder
+from ..clients import cozo, litellm
 from ..env import testing
 from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
 from .types import EmbedDocsPayload
```
```diff
@@ -14,8 +13,8 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
     embed_instruction: str = payload.embed_instruction or ""
     title: str = payload.title or ""

-    embeddings = await embedder.embed(
-        [
+    embeddings = await litellm.aembedding(
+        inputs=[
             (
                 embed_instruction + (title + "\n\n" + snippet) if title else snippet
             ).strip()
```
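The tail of this hunk is truncated above. As a rough sketch of what the new call receives (the payload values below are hypothetical, since `EmbedDocsPayload`'s snippet field isn't shown in this diff):

```python
# Hypothetical values standing in for fields of EmbedDocsPayload.
embed_instruction = "Represent the document for retrieval: "
title = "Company handbook"
snippets = ["Our mission is simple.", "Employees may work remotely."]

# Mirrors the comprehension in the hunk above: prepend the instruction and
# title (when present) to each snippet, then strip surrounding whitespace.
inputs = [
    (embed_instruction + (title + "\n\n" + snippet) if title else snippet).strip()
    for snippet in snippets
]

assert inputs[0] == (
    "Represent the document for retrieval: Company handbook\n\nOur mission is simple."
)
```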
55 changes: 51 additions & 4 deletions agents-api/agents_api/clients/litellm.py
```diff
@@ -1,21 +1,31 @@
 from functools import wraps
-from typing import List
+from typing import List, Literal

 from beartype import beartype
-from litellm import acompletion as _acompletion
-from litellm import get_supported_openai_params
+from litellm import (
+    acompletion as _acompletion,
+    aembedding as _aembedding,
+    get_supported_openai_params,
+)
+import litellm
 from litellm.utils import CustomStreamWrapper, ModelResponse

-from ..env import litellm_master_key, litellm_url
+from ..env import embedding_model_id, embedding_dimensions, litellm_master_key, litellm_url

 __all__: List[str] = ["acompletion"]

+# TODO: Should check if this is really needed
+litellm.drop_params = True
+

 @wraps(_acompletion)
 @beartype
 async def acompletion(
     *, model: str, messages: list[dict], custom_api_key: None | str = None, **kwargs
 ) -> ModelResponse | CustomStreamWrapper:
+    if not custom_api_key:
+        model = f"openai/{model}"  # FIXME: This is for litellm
+
     supported_params = get_supported_openai_params(model)
     settings = {k: v for k, v in kwargs.items() if k in supported_params}
```
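A minimal usage sketch of the prefixing behavior added above (assumed to run where a litellm proxy is reachable at `LITELLM_URL`; the model name is a placeholder):

```python
import asyncio

from agents_api.clients.litellm import acompletion

async def main() -> None:
    # With no custom_api_key, the wrapper rewrites "gpt-4o" to "openai/gpt-4o"
    # and routes the request through the litellm proxy with the master key.
    response = await acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Say hi."}],
        temperature=0.2,  # kept only if listed by get_supported_openai_params
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```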
```diff
@@ -26,3 +36,40 @@ async def acompletion(
         base_url=None if custom_api_key else litellm_url,
         api_key=custom_api_key or litellm_master_key,
     )
+
+
+@wraps(_aembedding)
+@beartype
+async def aembedding(
+    *,
+    inputs: str | list[str],
+    model: str = embedding_model_id,
+    dimensions: int = embedding_dimensions,
+    join_inputs: bool = False,
+    custom_api_key: None | str = None,
+    **settings,
+) -> list[list[float]]:
+    if not custom_api_key:
+        model = f"openai/{model}"  # FIXME: This is for litellm
+
+    if isinstance(inputs, str):
+        input = [inputs]
+    else:
+        input = ["\n\n".join(inputs)] if join_inputs else inputs
+
+    response = await _aembedding(
+        model=model,
+        input=input,
+        # dimensions=dimensions,  # FIXME: litellm doesn't support dimensions correctly
+        api_base=None if custom_api_key else litellm_url,
+        api_key=custom_api_key or litellm_master_key,
+        drop_params=True,
+        **settings,
+    )
+
+    embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data
+
+    # FIXME: Truncation should be handled by litellm
+    result = [embedding["embedding"][:dimensions] for embedding in embedding_list]
+
+    return result
```
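A usage sketch for the new wrapper (assumed to run inside the agents-api environment, with the litellm proxy configured):

```python
import asyncio

from agents_api.clients.litellm import aembedding

async def main() -> None:
    # A single string is wrapped into a one-element batch, so exactly one
    # vector comes back.
    [vector] = await aembedding(inputs="hello world")

    # A list is embedded row by row: one vector per input string ...
    vectors = await aembedding(inputs=["first doc", "second doc"])
    assert len(vectors) == 2

    # ... unless join_inputs=True, which concatenates the strings with
    # "\n\n" and returns a single vector for the combined text.
    [joined] = await aembedding(inputs=["first doc", "second doc"], join_inputs=True)

    # Each vector is truncated client-side to `dimensions`
    # (EMBEDDING_DIMENSIONS, default 1024).
    assert len(vector) <= 1024

asyncio.run(main())
```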
File renamed without changes.
18 changes: 0 additions & 18 deletions agents-api/agents_api/clients/vertexai.py

This file was deleted.

14 changes: 4 additions & 10 deletions agents-api/agents_api/env.py
```diff
@@ -25,6 +25,7 @@
 # -----
 task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)

+
 # Debug
 # -----
 debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
@@ -51,6 +52,7 @@

 api_key_header_name: str = env.str("AGENTS_API_KEY_HEADER_NAME", default="X-Auth-Key")

+
 # Litellm API
 # -----------
 litellm_url: str = env.str("LITELLM_URL", default="http://0.0.0.0:4000")
@@ -59,13 +61,11 @@

 # Embedding service
 # -----------------
-embedding_service_base: str = env.str(
-    "EMBEDDING_SERVICE_BASE", default="http://0.0.0.0:8082"
-)
 embedding_model_id: str = env.str(
     "EMBEDDING_MODEL_ID", default="Alibaba-NLP/gte-large-en-v1.5"
 )
-truncate_embed_text: bool = env.bool("TRUNCATE_EMBED_TEXT", default=True)

+embedding_dimensions: int = env.int("EMBEDDING_DIMENSIONS", default=1024)
+

 # Temporal
@@ -77,9 +77,6 @@
 temporal_endpoint: Any = env.str("TEMPORAL_ENDPOINT", default="localhost:7233")
 temporal_task_queue: Any = env.str("TEMPORAL_TASK_QUEUE", default="julep-task-queue")

-# Google cloud
-google_project_id: str = env.str("GOOGLE_PROJECT_ID")
-vertex_location: str = env.str("VERTEX_LOCATION", default="us-central1")

 # Consolidate environment variables
 environment: Dict[str, Any] = dict(
@@ -94,13 +91,10 @@
     api_key_header_name=api_key_header_name,
     hostname=hostname,
     api_prefix=api_prefix,
-    embedding_service_base=embedding_service_base,
-    truncate_embed_text=truncate_embed_text,
     temporal_worker_url=temporal_worker_url,
     temporal_namespace=temporal_namespace,
     embedding_model_id=embedding_model_id,
     testing=testing,
-    google_project_id=google_project_id,
 )

 if debug or testing:
```
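With these changes, the embedding model becomes a deployment-time setting. A sketch of overriding it (the model name below is hypothetical; `env.py` reads variables at import time, so they must be set before `agents_api.env` is first imported):

```python
import os

os.environ["EMBEDDING_MODEL_ID"] = "text-embedding-3-large"
os.environ["EMBEDDING_DIMENSIONS"] = "1024"

from agents_api.env import embedding_dimensions, embedding_model_id

assert embedding_model_id == "text-embedding-3-large"
assert embedding_dimensions == 1024
```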
9 changes: 4 additions & 5 deletions agents-api/agents_api/models/chat/gather_messages.py
```diff
@@ -9,7 +9,7 @@
 from agents_api.autogen.Chat import ChatInput

 from ...autogen.openapi_model import DocReference, History
-from ...clients import vertexai as embed
+from ...clients import litellm
 from ...common.protocol.developers import Developer
 from ...common.protocol.sessions import ChatContext
 from ..docs.search_docs_hybrid import search_docs_hybrid
@@ -61,12 +61,11 @@ async def gather_messages(
         return past_messages, []

     # Search matching docs
-    query_embedding = await embed.embed(
-        inputs=[
+    [query_embedding, *_] = await litellm.aembedding(
+        inputs="\n\n".join([
             f"{msg.get('name') or msg['role']}: {msg['content']}"
             for msg in new_raw_messages
-        ],
-        join_inputs=True,
+        ]),
     )
     query_text = new_raw_messages[-1]["content"]
```
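An illustration (with hypothetical messages) of the single query string the rewritten call embeds:

```python
new_raw_messages = [
    {"role": "user", "content": "What is our refund policy?"},
    {"role": "assistant", "name": "support-bot", "content": "Let me look that up."},
]

# Same formatting as the hunk above: "name-or-role: content" per message,
# joined into one block so aembedding returns exactly one vector.
query = "\n\n".join(
    f"{msg.get('name') or msg['role']}: {msg['content']}"
    for msg in new_raw_messages
)

assert query == (
    "user: What is our refund policy?\n\nsupport-bot: Let me look that up."
)
# [query_embedding, *_] = await litellm.aembedding(inputs=query) then unpacks
# the one-element result.
```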
5 changes: 2 additions & 3 deletions agents-api/agents_api/routers/docs/embed.py
```diff
@@ -3,12 +3,11 @@

 from fastapi import Depends

-import agents_api.clients.vertexai as embedder
-
 from ...autogen.openapi_model import (
     EmbedQueryRequest,
     EmbedQueryResponse,
 )
+from ...clients import litellm
 from ...dependencies.developer_id import get_developer_id
 from .router import router

@@ -23,6 +22,6 @@ async def embed(
         [text_to_embed] if isinstance(text_to_embed, str) else text_to_embed
     )

-    vectors = await embedder.embed(inputs=text_to_embed)
+    vectors = await litellm.aembedding(inputs=text_to_embed)

     return EmbedQueryResponse(vectors=vectors)
```
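A hedged sketch of exercising this route with FastAPI's test client. The app import, route path, and request field are assumptions (they come from the autogenerated OpenAPI models and router setup, which are not part of this diff):

```python
from fastapi.testclient import TestClient

from agents_api.web import app  # assumed location of the FastAPI app

client = TestClient(app)

response = client.post(
    "/embed",  # assumed route path
    json={"text": "embed me"},  # assumed EmbedQueryRequest field name
    headers={"X-Auth-Key": "<api key>"},  # default api_key_header_name from env.py
)
vectors = response.json()["vectors"]
```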
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/sessions/chat.py
```diff
@@ -47,7 +47,6 @@ async def chat(
     # Merge the settings and prepare environment
     chat_context.merge_settings(chat_input)
     settings: dict = chat_context.settings.model_dump()
-    settings["model"] = f"openai/{settings['model']}"  # litellm proxy idiosyncracy

     # Get the past messages and doc references
     past_messages, doc_references = await gather_messages(
```
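The deleted line is now redundant: `acompletion` in `clients/litellm.py` applies the `openai/` prefix itself whenever no custom API key is supplied (see the first hunk of that file above).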