|
28 | 28 | ProviderModel, |
29 | 29 | Session, |
30 | 30 | WorkspaceRow, |
| 31 | + WorkspaceWithModel, |
31 | 32 | WorkspaceWithSessionInfo, |
32 | 33 | ) |
33 | 34 | from codegate.db.token_usage import TokenUsageParser |
@@ -129,6 +130,7 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option |
129 | 130 | active_workspace = await DbReader().get_active_workspace() |
130 | 131 | workspace_id = active_workspace.id if active_workspace else "1" |
131 | 132 | prompt_params.workspace_id = workspace_id |
| 133 | + |
132 | 134 | sql = text( |
133 | 135 | """ |
134 | 136 | INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id) |
@@ -302,7 +304,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None: |
302 | 304 | await self.record_outputs(context.output_responses, initial_id) |
303 | 305 | await self.record_alerts(context.alerts_raised, initial_id) |
304 | 306 | logger.info( |
305 | | - f"Recorded context in DB. Output chunks: {len(context.output_responses)}. " |
| 307 | + f"Updated context in DB. Output chunks: {len(context.output_responses)}. " |
306 | 308 | f"Alerts: {len(context.alerts_raised)}." |
307 | 309 | ) |
308 | 310 | except Exception as e: |
@@ -720,6 +722,23 @@ async def get_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]: |
720 | 722 | ) |
721 | 723 | return workspaces[0] if workspaces else None |
722 | 724 |
|
| 725 | + async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceWithModel]: |
| 726 | + sql = text( |
| 727 | + """ |
| 728 | + SELECT |
| 729 | + w.id, w.name, m.provider_model_name |
| 730 | + FROM workspaces w |
| 731 | + JOIN muxes m ON w.id = m.workspace_id |
| 732 | + WHERE m.provider_endpoint_id = :provider_id |
| 733 | + AND w.deleted_at IS NULL |
| 734 | + """ |
| 735 | + ) |
| 736 | + conditions = {"provider_id": provider_id} |
| 737 | + workspaces = await self._exec_select_conditions_to_pydantic( |
| 738 | + WorkspaceWithModel, sql, conditions, should_raise=True |
| 739 | + ) |
| 740 | + return workspaces |
| 741 | + |
723 | 742 | async def get_archived_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]: |
724 | 743 | sql = text( |
725 | 744 | """ |
|
0 commit comments