Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

feat: Let the user add their own system prompts #643

Merged
merged 8 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions migrations/versions/a692c8b52308_add_workspace_system_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add_workspace_system_prompt

Revision ID: a692c8b52308
Revises: 5c2f3eee5f90
Create Date: 2025-01-17 16:33:58.464223

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "a692c8b52308"
down_revision: Union[str, None] = "5c2f3eee5f90"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Add column to workspaces table
op.execute("ALTER TABLE workspaces ADD COLUMN system_prompt TEXT DEFAULT NULL;")


def downgrade() -> None:
op.execute("ALTER TABLE workspaces DROP COLUMN system_prompt;")
26 changes: 20 additions & 6 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,15 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
async def add_workspace(self, workspace_name: str) -> Workspace:
"""Add a new workspace to the DB.

This handles validation and insertion of a new workspace.

It may raise a ValidationError if the workspace name is invalid.
or a AlreadyExistsError if the workspace already exists.
"""
workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)

workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name, system_prompt=None)
sql = text(
"""
INSERT INTO workspaces (id, name)
Expand All @@ -275,6 +274,21 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
return added_workspace

async def update_workspace(self, workspace: Workspace) -> Workspace:
sql = text(
"""
UPDATE workspaces SET
name = :name,
system_prompt = :system_prompt
WHERE id = :id
RETURNING *
"""
)
updated_workspace = await self._execute_update_pydantic_model(
workspace, sql, should_raise=True
)
return updated_workspace

async def update_session(self, session: Session) -> Optional[Session]:
sql = text(
"""
Expand Down Expand Up @@ -392,11 +406,11 @@ async def get_workspaces(self) -> List[WorkspaceActive]:
workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql)
return workspaces

async def get_workspace_by_name(self, name: str) -> List[Workspace]:
async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
sql = text(
"""
SELECT
id, name
id, name, system_prompt
FROM workspaces
WHERE name = :name
"""
Expand All @@ -422,7 +436,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
sql = text(
"""
SELECT
w.id, w.name, s.id as session_id, s.last_update
w.id, w.name, w.system_prompt, s.id as session_id, s.last_update
FROM sessions s
INNER JOIN workspaces w ON w.id = s.active_workspace_id
"""
Expand Down
2 changes: 2 additions & 0 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Setting(BaseModel):
class Workspace(BaseModel):
id: str
name: str
system_prompt: Optional[str]

@field_validator("name", mode="plain")
@classmethod
Expand Down Expand Up @@ -98,5 +99,6 @@ class WorkspaceActive(BaseModel):
class ActiveWorkspace(BaseModel):
id: str
name: str
system_prompt: Optional[str]
session_id: str
last_update: datetime.datetime
3 changes: 2 additions & 1 deletion src/codegate/pipeline/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
PipelineResult,
PipelineStep,
)
from codegate.pipeline.cli.commands import Version, Workspace
from codegate.pipeline.cli.commands import SystemPrompt, Version, Workspace

HELP_TEXT = """
## CodeGate CLI\n
Expand All @@ -32,6 +32,7 @@ async def codegate_cli(command):
available_commands = {
"version": Version().exec,
"workspace": Workspace().exec,
"system-prompt": SystemPrompt().exec,
}
out_func = available_commands.get(command[0])
if out_func is None:
Expand Down
Loading
Loading