This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Create mux endpoint and route to appropriate destination provider #827

Merged 2 commits on Jan 31, 2025
17 changes: 2 additions & 15 deletions src/codegate/api/v1_models.py
@@ -107,25 +107,12 @@ class PartialQuestions(pydantic.BaseModel):
     type: QuestionType
 
 
-class ProviderType(str, Enum):
-    """
-    Represents the different types of providers we support.
-    """
-
-    openai = "openai"
-    anthropic = "anthropic"
-    vllm = "vllm"
-    ollama = "ollama"
-    lm_studio = "lm_studio"
-    llamacpp = "llamacpp"
-
-
 class TokenUsageByModel(pydantic.BaseModel):
     """
     Represents the tokens used by a model.
     """
 
-    provider_type: ProviderType
+    provider_type: db_models.ProviderType
     model: str
     token_usage: db_models.TokenUsage
 
@@ -221,7 +208,7 @@ class ProviderEndpoint(pydantic.BaseModel):
     id: Optional[str] = ""
     name: str
     description: str = ""
-    provider_type: ProviderType
+    provider_type: db_models.ProviderType
     endpoint: str = ""  # Some providers have defaults we can leverage
     auth_type: Optional[ProviderAuthType] = ProviderAuthType.none
 
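Note: this hunk deletes the API layer's copy of ProviderType and repoints provider_type at the canonical enum now defined in src/codegate/db/models.py (added below). A minimal sketch of post-change usage, assuming only what the diff shows:

from codegate.db import models as db_models

# Round-trip a provider string through the shared enum; unknown values raise ValueError.
provider = db_models.ProviderType("anthropic")
assert provider is db_models.ProviderType.anthropic
assert provider.value == "anthropic"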
21 changes: 21 additions & 0 deletions src/codegate/db/connection.py
@@ -20,6 +20,7 @@
     GetPromptWithOutputsRow,
     GetWorkspaceByNameConditions,
     MuxRule,
+    MuxRuleProviderEndpoint,
     Output,
     Prompt,
     ProviderAuthMaterial,
@@ -777,6 +778,26 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
         )
         return muxes
 
+    async def get_muxes_with_provider_by_workspace(
+        self, workspace_id: str
+    ) -> List[MuxRuleProviderEndpoint]:
+        sql = text(
+            """
+            SELECT m.id, m.provider_endpoint_id, m.provider_model_name, m.workspace_id,
+                m.matcher_type, m.matcher_blob, m.priority, m.created_at, m.updated_at,
+                pe.provider_type, pe.endpoint, pe.auth_type, pe.auth_blob
+            FROM muxes m
+            INNER JOIN provider_endpoints pe ON pe.id = m.provider_endpoint_id
+            WHERE m.workspace_id = :workspace_id
+            ORDER BY priority ASC
+            """
+        )
+        conditions = {"workspace_id": workspace_id}
+        muxes = await self._exec_select_conditions_to_pydantic(
+            MuxRuleProviderEndpoint, sql, conditions, should_raise=True
+        )
+        return muxes
+
 
 def init_db_sync(db_path: Optional[str] = None):
     """DB will be initialized in the constructor in case it doesn't exist."""
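The new query returns each mux rule joined with its provider endpoint, ordered by priority ascending, so the first matching rule wins. A hedged usage sketch; the DbReader host class and the workspace id are assumptions, since the diff only shows the method body:

import asyncio

from codegate.db.connection import DbReader  # assumed host class for the new method


async def show_muxes(workspace_id: str) -> None:
    reader = DbReader()
    # Each row carries the mux rule plus the provider columns needed for routing.
    for mux in await reader.get_muxes_with_provider_by_workspace(workspace_id):
        print(mux.priority, mux.matcher_type, mux.provider_type, mux.endpoint)


asyncio.run(show_muxes("workspace-123"))  # workspace id invented for illustration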
30 changes: 30 additions & 0 deletions src/codegate/db/models.py
@@ -1,4 +1,5 @@
 import datetime
+from enum import Enum
 from typing import Annotated, Any, Dict, Optional
 
 from pydantic import BaseModel, StringConstraints
@@ -116,6 +117,19 @@ class Session(BaseModel):
 # Models for select queries
 
 
+class ProviderType(str, Enum):
+    """
+    Represents the different types of providers we support.
+    """
+
+    openai = "openai"
+    anthropic = "anthropic"
+    vllm = "vllm"
+    ollama = "ollama"
+    lm_studio = "lm_studio"
+    llamacpp = "llamacpp"
+
+
 class GetPromptWithOutputsRow(BaseModel):
     id: Any
     timestamp: Any
@@ -185,3 +199,19 @@ class MuxRule(BaseModel):
     priority: int
     created_at: Optional[datetime.datetime] = None
     updated_at: Optional[datetime.datetime] = None
+
+
+class MuxRuleProviderEndpoint(BaseModel):
+    id: str
+    provider_endpoint_id: str
+    provider_model_name: str
+    workspace_id: str
+    matcher_type: str
+    matcher_blob: str
+    priority: int
+    created_at: Optional[datetime.datetime] = None
+    updated_at: Optional[datetime.datetime] = None
+    provider_type: ProviderType
+    endpoint: str
+    auth_type: str
+    auth_blob: str
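MuxRuleProviderEndpoint flattens one row of that join: the mux rule's own columns plus the provider columns needed to route a request. A hedged example of the shape such a row takes (all field values invented for illustration):

from codegate.db import models as db_models

row = db_models.MuxRuleProviderEndpoint(
    id="mux-1",
    provider_endpoint_id="prov-1",
    provider_model_name="claude-3-5-sonnet",
    workspace_id="workspace-123",
    matcher_type="catch_all",
    matcher_blob="",
    priority=0,
    provider_type=db_models.ProviderType.anthropic,
    endpoint="https://api.anthropic.com",
    auth_type="api_key",
    auth_blob="<secret material>",  # format depends on auth_type
)
assert row.provider_type is db_models.ProviderType.anthropic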
212 changes: 212 additions & 0 deletions src/codegate/mux/adapter.py
@@ -0,0 +1,212 @@
import copy
import json
import uuid
from typing import Union

import structlog
from fastapi.responses import JSONResponse, StreamingResponse
from litellm import ModelResponse
from litellm.types.utils import Delta, StreamingChoices
from ollama import ChatResponse

from codegate.db import models as db_models
from codegate.providers.ollama.adapter import OLlamaToModel

logger = structlog.get_logger("codegate")


class MuxingAdapterError(Exception):
    pass


class BodyAdapter:
    """
    Map the body between OpenAI and Anthropic.

    TODO: These are really naive implementations. We should replace them with more robust ones.
    """

    def _from_openai_to_anthropic(self, openai_body: dict) -> dict:
        """Map the OpenAI body to the Anthropic body."""
        new_body = copy.deepcopy(openai_body)
        messages = new_body.get("messages", [])
        system_prompt = None
        system_msg_idx = None
        if messages:
            for i_msg, msg in enumerate(messages):
                if msg.get("role", "") == "system":
                    system_prompt = msg.get("content")
                    system_msg_idx = i_msg
                    break
        if system_prompt:
            new_body["system"] = system_prompt
        if system_msg_idx is not None:
            del messages[system_msg_idx]
        return new_body

    def _from_anthropic_to_openai(self, anthropic_body: dict) -> dict:
        """Map the Anthropic body to the OpenAI body."""
        new_body = copy.deepcopy(anthropic_body)
        system_prompt = anthropic_body.get("system")
        messages = new_body.get("messages", [])
        if system_prompt:
            messages.insert(0, {"role": "system", "content": system_prompt})
        if "system" in new_body:
            del new_body["system"]
        return new_body

    def _get_provider_formatted_url(
        self, mux_and_provider: db_models.MuxRuleProviderEndpoint
    ) -> str:
        """Get the provider-formatted URL to use as base_url. Note this value comes from the DB."""
        if mux_and_provider.provider_type == db_models.ProviderType.openai:
            return f"{mux_and_provider.endpoint}/v1"
        return mux_and_provider.endpoint

    def _set_destination_info(
        self, data: dict, mux_and_provider: db_models.MuxRuleProviderEndpoint
    ) -> dict:
        """Set the destination provider info."""
        new_data = copy.deepcopy(data)
        new_data["model"] = mux_and_provider.provider_model_name
        new_data["base_url"] = self._get_provider_formatted_url(mux_and_provider)
        return new_data

    def _identify_provider(self, data: dict) -> db_models.ProviderType:
        """Identify the request provider."""
        if "system" in data:
            return db_models.ProviderType.anthropic
        else:
            return db_models.ProviderType.openai

    def map_body_to_dest(
        self, mux_and_provider: db_models.MuxRuleProviderEndpoint, data: dict
    ) -> dict:
        """
        Map the body to the destination provider.

        We only need to transform the body if the destination or origin provider is Anthropic.
        """
        origin_prov = self._identify_provider(data)
        if mux_and_provider.provider_type == db_models.ProviderType.anthropic:
            if origin_prov != db_models.ProviderType.anthropic:
                data = self._from_openai_to_anthropic(data)
        else:
            if origin_prov == db_models.ProviderType.anthropic:
                data = self._from_anthropic_to_openai(data)
        return self._set_destination_info(data, mux_and_provider)


class StreamChunkFormatter:
    """
    Format a single chunk from a stream to OpenAI format.
    We need to configure the client to expect the OpenAI format.
    In Continue this means setting "provider": "openai" in the config JSON file.
    """

    def __init__(self):
        self.provider_to_func = {
            db_models.ProviderType.ollama: self._format_ollama,
            db_models.ProviderType.openai: self._format_openai,
            db_models.ProviderType.anthropic: self._format_anthropic,
        }

    def _format_ollama(self, chunk: str) -> str:
        """Format the Ollama chunk to OpenAI format."""
        try:
            chunk_dict = json.loads(chunk)
            ollama_chunk = ChatResponse(**chunk_dict)
            open_ai_chunk = OLlamaToModel.normalize_chunk(ollama_chunk)
            return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
        except Exception:
            return chunk

    def _format_openai(self, chunk: str) -> str:
        """The chunk is already in OpenAI format. To standardize, remove the "data:" prefix."""
        cleaned_chunk = chunk.split("data:")[1].strip()
        try:
            chunk_dict = json.loads(cleaned_chunk)
            open_ai_chunk = ModelResponse(**chunk_dict)
            return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
        except Exception:
            return cleaned_chunk

    def _format_anthropic(self, chunk: str) -> str:
        """Format the Anthropic chunk to OpenAI format."""
        cleaned_chunk = chunk.split("data:")[1].strip()
        try:
            chunk_dict = json.loads(cleaned_chunk)
            msg_type = chunk_dict.get("type", "")

            finish_reason = None
            if msg_type == "message_stop":
                finish_reason = "stop"

            # In type == "content_block_start" the content comes in "content_block"
            # In type == "content_block_delta" the content comes in "delta"
            msg_content_dict = chunk_dict.get("delta", {}) or chunk_dict.get("content_block", {})
            # We couldn't obtain the content from the chunk. Skip it.
            if not msg_content_dict:
                return ""

            msg_content = msg_content_dict.get("text", "")
            open_ai_chunk = ModelResponse(
                id=f"anthropic-chat-{str(uuid.uuid4())}",
                model="anthropic-muxed-model",
                object="chat.completion.chunk",
                choices=[
                    StreamingChoices(
                        finish_reason=finish_reason,
                        index=0,
                        delta=Delta(content=msg_content, role="assistant"),
                        logprobs=None,
                    )
                ],
            )
            return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
        except Exception:
            return cleaned_chunk.strip()

    def format(self, chunk: str, dest_prov: db_models.ProviderType) -> str:
        """Format the chunk to OpenAI format."""
        # Get the format function
        format_func = self.provider_to_func.get(dest_prov)
        if format_func is None:
            raise MuxingAdapterError(f"Provider {dest_prov} not supported.")
        return format_func(chunk)


class ResponseAdapter:

    def __init__(self):
        self.stream_formatter = StreamChunkFormatter()

    def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
        """Format the chunk as an OpenAI chunk. This is the format the clients expect."""
        return f"data:{formatted_chunk}\n\n"

    async def _format_streaming_response(
        self, response: StreamingResponse, dest_prov: db_models.ProviderType
    ):
        """Format the streaming response to OpenAI format."""
        async for chunk in response.body_iterator:
            openai_chunk = self.stream_formatter.format(chunk, dest_prov)
            # Sometimes (e.g. for Anthropic) no content can be extracted from the chunk. Skip it.
            if not openai_chunk:
                continue
            yield self._format_as_openai_chunk(openai_chunk)

    def format_response_to_client(
        self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
    ) -> Union[StreamingResponse, JSONResponse]:
        """Format the response to the client."""
        if isinstance(response, StreamingResponse):
            return StreamingResponse(
                self._format_streaming_response(response, dest_prov),
                status_code=response.status_code,
                headers=response.headers,
                background=response.background,
                media_type=response.media_type,
            )
        else:
            raise MuxingAdapterError("Only streaming responses are supported.")
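Taken together: BodyAdapter.map_body_to_dest rewrites an incoming body for the destination provider (consuming a MuxRuleProviderEndpoint row like the one sketched under the models.py diff), and ResponseAdapter re-emits the provider's stream as OpenAI-style chunks. The chunk formatter is pure string-in/string-out, so it can be exercised alone; a minimal sketch with a simplified Anthropic SSE chunk (real chunks carry more fields):

import json

from codegate.db import models as db_models
from codegate.mux.adapter import StreamChunkFormatter

formatter = StreamChunkFormatter()
anthropic_chunk = 'data: {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hi"}}'
openai_json = formatter.format(anthropic_chunk, db_models.ProviderType.anthropic)
# The result is an OpenAI chat.completion.chunk JSON string; the delta text survives the translation.
assert json.loads(openai_json)["choices"][0]["delta"]["content"] == "Hi"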