1- import asyncio
21import datetime
32from typing import List , Optional , Tuple
43from uuid import uuid4 as uuid
1110 WorkspaceRow ,
1211 WorkspaceWithSessionInfo ,
1312)
13+ from codegate .muxing import rulematcher
1414
1515
1616class WorkspaceCrudError (Exception ):
@@ -37,8 +37,12 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError):
3737
3838class WorkspaceCrud :
3939
40- def __init__ (self ):
40+ def __init__ (
41+ self ,
42+ mux_registry : rulematcher .MuxingRulesinWorkspaces = rulematcher .get_muxing_rules_registry (),
43+ ):
4144 self ._db_reader = DbReader ()
45+ self ._mux_registry = mux_registry
4246
4347 async def add_workspace (self , new_workspace_name : str ) -> WorkspaceRow :
4448 """
@@ -135,6 +139,9 @@ async def activate_workspace(self, workspace_name: str):
135139 session .last_update = datetime .datetime .now (datetime .timezone .utc )
136140 db_recorder = DbRecorder ()
137141 await db_recorder .update_session (session )
142+
143+ # Ensure the mux registry is updated
144+ self ._mux_registry .set_active_workspace (workspace .id )
138145 return
139146
140147 async def recover_workspace (self , workspace_name : str ):
@@ -189,6 +196,9 @@ async def soft_delete_workspace(self, workspace_name: str):
189196 _ = await db_recorder .soft_delete_workspace (selected_workspace )
190197 except Exception :
191198 raise WorkspaceCrudError (f"Error deleting workspace { workspace_name } " )
199+
200+ # Remove the muxes from the registry
201+ del self ._mux_registry [workspace_name ]
192202 return
193203
194204 async def hard_delete_workspace (self , workspace_name : str ):
@@ -243,6 +253,8 @@ async def get_muxes(self, workspace_name: str):
243253
244254 # Can't use type hints since the models are not yet defined
245255 async def set_muxes (self , workspace_name : str , muxes ):
256+ from codegate .api import v1_models
257+
246258 # Verify if workspace exists
247259 workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
248260 if not workspace :
@@ -252,23 +264,19 @@ async def set_muxes(self, workspace_name: str, muxes):
252264 db_recorder = DbRecorder ()
253265 await db_recorder .delete_muxes_by_workspace (workspace .id )
254266
255- tasks = set ()
256-
257267 # Add the new muxes
258268 priority = 0
259269
270+ muxes_with_routes : List [Tuple [v1_models .MuxRule , rulematcher .ModelRoute ]] = []
271+
260272 # Verify all models are valid
261273 for mux in muxes :
262- dbm = await self ._db_reader .get_provider_model_by_provider_id_and_name (
263- mux .provider_id ,
264- mux .model ,
265- )
266- if not dbm :
267- raise WorkspaceCrudError (
268- f"Model { mux .model } does not exist for provider { mux .provider_id } "
269- )
274+ route = await self .get_routing_for_mux (mux )
275+ muxes_with_routes .append ((mux , route ))
270276
271- for mux in muxes :
277+ matchers : List [rulematcher .MuxingRuleMatcher ] = []
278+
279+ for mux , route in muxes_with_routes :
272280 new_mux = MuxRule (
273281 id = str (uuid ()),
274282 provider_endpoint_id = mux .provider_id ,
@@ -278,8 +286,92 @@ async def set_muxes(self, workspace_name: str, muxes):
278286 matcher_blob = mux .matcher if mux .matcher else "" ,
279287 priority = priority ,
280288 )
281- tasks .add (db_recorder .add_mux (new_mux ))
289+ dbmux = await db_recorder .add_mux (new_mux )
290+
291+ matchers .append (rulematcher .MuxingMatcherFactory .create (dbmux , route ))
282292
283293 priority += 1
284294
285- await asyncio .gather (* tasks )
295+ # Set routing list for the workspace
296+ self ._mux_registry [workspace_name ] = matchers
297+
298+ async def get_routing_for_mux (self , mux ) -> rulematcher .ModelRoute :
299+ """Get the routing for a mux
300+
301+ Note that this particular mux object is the API model, not the database model.
302+ It's only not annotated because of a circular import issue.
303+ """
304+ dbprov = await self ._db_reader .get_provider_endpoint_by_id (mux .provider_id )
305+ if not dbprov :
306+ raise WorkspaceCrudError (f"Provider { mux .provider_id } does not exist" )
307+
308+ dbm = await self ._db_reader .get_provider_model_by_provider_id_and_name (
309+ mux .provider_id ,
310+ mux .model ,
311+ )
312+ if not dbm :
313+ raise WorkspaceCrudError (
314+ f"Model { mux .model } does not exist for provider { mux .provider_id } "
315+ )
316+ dbauth = await self ._db_reader .get_auth_material_by_provider_id (mux .provider_id )
317+ if not dbauth :
318+ raise WorkspaceCrudError (f"Auth material for provider { mux .provider_id } does not exist" )
319+
320+ return rulematcher .ModelRoute (
321+ provider = dbprov ,
322+ model = dbm ,
323+ auth = dbauth ,
324+ )
325+
326+ async def get_routing_for_db_mux (self , mux : MuxRule ) -> rulematcher .ModelRoute :
327+ """Get the routing for a mux
328+
329+ Note that this particular mux object is the database model, not the API model.
330+ It's only not annotated because of a circular import issue.
331+ """
332+ dbprov = await self ._db_reader .get_provider_endpoint_by_id (mux .provider_endpoint_id )
333+ if not dbprov :
334+ raise WorkspaceCrudError (f"Provider { mux .provider_endpoint_id } does not exist" )
335+
336+ dbm = await self ._db_reader .get_provider_model_by_provider_id_and_name (
337+ mux .provider_endpoint_id ,
338+ mux .provider_model_name ,
339+ )
340+ if not dbm :
341+ raise WorkspaceCrudError (
342+ f"Model { mux .provider_model_name } does not "
343+ "exist for provider {mux.provider_endpoint_id}"
344+ )
345+ dbauth = await self ._db_reader .get_auth_material_by_provider_id (mux .provider_endpoint_id )
346+ if not dbauth :
347+ raise WorkspaceCrudError (
348+ f"Auth material for provider { mux .provider_endpoint_id } does not exist"
349+ )
350+
351+ return rulematcher .ModelRoute (
352+ model = dbm ,
353+ endpoint = dbprov ,
354+ auth_material = dbauth ,
355+ )
356+
357+ async def initialize_mux_registry (self ):
358+ """Initialize the mux registry with all workspaces in the database"""
359+
360+ active_ws = await self .get_active_workspace ()
361+ if active_ws :
362+ self ._mux_registry .set_active_workspace (active_ws .name )
363+
364+ # Get all workspaces
365+ workspaces = await self .get_workspaces ()
366+
367+ # For each workspace, get the muxes and set them in the registry
368+ for ws in workspaces :
369+ muxes = await self ._db_reader .get_muxes_by_workspace (ws .id )
370+
371+ matchers : List [rulematcher .MuxingRuleMatcher ] = []
372+
373+ for mux in muxes :
374+ route = await self .get_routing_for_db_mux (mux )
375+ matchers .append (rulematcher .MuxingMatcherFactory .create (mux , route ))
376+
377+ self ._mux_registry [ws .name ] = matchers
0 commit comments