Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError

import codegate.muxing.models as mux_models
from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.db.connection import AlreadyExistsError, DbReader
Expand Down Expand Up @@ -477,7 +478,7 @@ async def delete_workspace_custom_instructions(workspace_name: str):
)
async def get_workspace_muxes(
workspace_name: str,
) -> List[v1_models.MuxRule]:
) -> List[mux_models.MuxRule]:
"""Get the mux rules of a workspace.

The list is ordered in order of priority. That is, the first rule in the list
Expand All @@ -501,7 +502,7 @@ async def get_workspace_muxes(
)
async def set_workspace_muxes(
workspace_name: str,
request: List[v1_models.MuxRule],
request: List[mux_models.MuxRule],
):
"""Set the mux rules of a workspace."""
try:
Expand Down
23 changes: 0 additions & 23 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,26 +267,3 @@ class ModelByProvider(pydantic.BaseModel):

def __str__(self):
return f"{self.provider_name} / {self.name}"


class MuxMatcherType(str, Enum):
"""
Represents the different types of matchers we support.
"""

# Always match this prompt
catch_all = "catch_all"


class MuxRule(pydantic.BaseModel):
"""
Represents a mux rule for a provider.
"""

provider_id: str
model: str
# The type of matcher to use
matcher_type: MuxMatcherType
# The actual matcher to use. Note that
# this depends on the matcher type.
matcher: Optional[str] = None
27 changes: 27 additions & 0 deletions src/codegate/muxing/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from enum import Enum
from typing import Optional

import pydantic


class MuxMatcherType(str, Enum):
"""
Represents the different types of matchers we support.
"""

# Always match this prompt
catch_all = "catch_all"


class MuxRule(pydantic.BaseModel):
"""
Represents a mux rule for a provider.
"""

provider_id: str
model: str
# The type of matcher to use
matcher_type: MuxMatcherType
# The actual matcher to use. Note that
# this depends on the matcher type.
matcher: Optional[str] = None
50 changes: 19 additions & 31 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@
from typing import List, Optional, Tuple
from uuid import uuid4 as uuid

from codegate.db import models as db_models
from codegate.db.connection import DbReader, DbRecorder
from codegate.db.models import (
ActiveWorkspace,
MuxRule,
Session,
WorkspaceRow,
WorkspaceWithSessionInfo,
)
from codegate.muxing import models as mux_models
from codegate.muxing import rulematcher


Expand Down Expand Up @@ -40,7 +35,7 @@ class WorkspaceCrud:
def __init__(self):
self._db_reader = DbReader()

async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:
async def add_workspace(self, new_workspace_name: str) -> db_models.WorkspaceRow:
"""
Add a workspace

Expand All @@ -57,7 +52,7 @@ async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:

async def rename_workspace(
self, old_workspace_name: str, new_workspace_name: str
) -> WorkspaceRow:
) -> db_models.WorkspaceRow:
"""
Rename a workspace

Expand All @@ -79,33 +74,33 @@ async def rename_workspace(
if not ws:
raise WorkspaceDoesNotExistError(f"Workspace {old_workspace_name} does not exist.")
db_recorder = DbRecorder()
new_ws = WorkspaceRow(
new_ws = db_models.WorkspaceRow(
id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions
)
workspace_renamed = await db_recorder.update_workspace(new_ws)
return workspace_renamed

async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
async def get_workspaces(self) -> List[db_models.WorkspaceWithSessionInfo]:
"""
Get all workspaces
"""
return await self._db_reader.get_workspaces()

async def get_archived_workspaces(self) -> List[WorkspaceRow]:
async def get_archived_workspaces(self) -> List[db_models.WorkspaceRow]:
"""
Get all archived workspaces
"""
return await self._db_reader.get_archived_workspaces()

async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
async def get_active_workspace(self) -> Optional[db_models.ActiveWorkspace]:
"""
Get the active workspace
"""
return await self._db_reader.get_active_workspace()

async def _is_workspace_active(
self, workspace_name: str
) -> Tuple[bool, Optional[Session], Optional[WorkspaceRow]]:
) -> Tuple[bool, Optional[db_models.Session], Optional[db_models.WorkspaceRow]]:
"""
Check if the workspace is active alongside the session and workspace objects
"""
Expand Down Expand Up @@ -155,13 +150,13 @@ async def recover_workspace(self, workspace_name: str):

async def update_workspace_custom_instructions(
self, workspace_name: str, custom_instr_lst: List[str]
) -> WorkspaceRow:
) -> db_models.WorkspaceRow:
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not selected_workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")

custom_instructions = " ".join(custom_instr_lst)
workspace_update = WorkspaceRow(
workspace_update = db_models.WorkspaceRow(
id=selected_workspace.id,
name=selected_workspace.name,
custom_instructions=custom_instructions,
Expand Down Expand Up @@ -217,17 +212,13 @@ async def hard_delete_workspace(self, workspace_name: str):
raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}")
return

async def get_workspace_by_name(self, workspace_name: str) -> WorkspaceRow:
async def get_workspace_by_name(self, workspace_name: str) -> db_models.WorkspaceRow:
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
return workspace

# Can't use type hints since the models are not yet defined
# Note that I'm explicitly importing the models here to avoid circular imports.
async def get_muxes(self, workspace_name: str):
from codegate.api import v1_models

async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]:
# Verify if workspace exists
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
Expand All @@ -239,7 +230,7 @@ async def get_muxes(self, workspace_name: str):
# These are already sorted by priority
for dbmux in dbmuxes:
muxes.append(
v1_models.MuxRule(
mux_models.MuxRule(
provider_id=dbmux.provider_endpoint_id,
model=dbmux.provider_model_name,
matcher_type=dbmux.matcher_type,
Expand All @@ -249,10 +240,7 @@ async def get_muxes(self, workspace_name: str):

return muxes

# Can't use type hints since the models are not yet defined
async def set_muxes(self, workspace_name: str, muxes):
from codegate.api import v1_models

async def set_muxes(self, workspace_name: str, muxes: mux_models.MuxRule) -> None:
# Verify if workspace exists
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
Expand All @@ -265,7 +253,7 @@ async def set_muxes(self, workspace_name: str, muxes):
# Add the new muxes
priority = 0

muxes_with_routes: List[Tuple[v1_models.MuxRule, rulematcher.ModelRoute]] = []
muxes_with_routes: List[Tuple[mux_models.MuxRule, rulematcher.ModelRoute]] = []

# Verify all models are valid
for mux in muxes:
Expand All @@ -275,7 +263,7 @@ async def set_muxes(self, workspace_name: str, muxes):
matchers: List[rulematcher.MuxingRuleMatcher] = []

for mux, route in muxes_with_routes:
new_mux = MuxRule(
new_mux = db_models.MuxRule(
id=str(uuid()),
provider_endpoint_id=mux.provider_id,
provider_model_name=mux.model,
Expand All @@ -294,7 +282,7 @@ async def set_muxes(self, workspace_name: str, muxes):
mux_registry = await rulematcher.get_muxing_rules_registry()
await mux_registry.set_ws_rules(workspace_name, matchers)

async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute:
"""Get the routing for a mux

Note that this particular mux object is the API model, not the database model.
Expand Down Expand Up @@ -322,7 +310,7 @@ async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
auth_material=dbauth,
)

async def get_routing_for_db_mux(self, mux: MuxRule) -> rulematcher.ModelRoute:
async def get_routing_for_db_mux(self, mux: db_models.MuxRule) -> rulematcher.ModelRoute:
"""Get the routing for a mux

Note that this particular mux object is the database model, not the API model.
Expand Down
Loading