Skip to content

Commit

Permalink
Add assistant gallery
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves committed May 30, 2024
1 parent 44d57f1 commit b690ae0
Show file tree
Hide file tree
Showing 46 changed files with 2,072 additions and 474 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Add chosen_assistants to User table
Revision ID: a3bfd0d64902
Revises: ec85f2b3c544
Create Date: 2024-05-26 17:22:24.834741
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a3bfd0d64902"
down_revision = "ec85f2b3c544"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)


def downgrade() -> None:
op.drop_column("user", "chosen_assistants")
40 changes: 40 additions & 0 deletions backend/danswer/auth/noauth_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from collections.abc import Mapping
from typing import Any
from typing import cast

from danswer.auth.schemas import UserRole
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences


NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"


def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())


def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None)


def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@danswer.ai",
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
)
20 changes: 20 additions & 0 deletions backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import lru_cache
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import not_
Expand Down Expand Up @@ -398,6 +399,23 @@ def get_persona_by_id(
return persona


def check_user_can_edit_persona(user: User | None, persona: Persona) -> None:
# if user is None, assume that no-auth is turned on
if user is None:
return

# admins can edit everything
if user.role == UserRole.ADMIN:
return

# otherwise, make sure user owns persona
if persona.user_id != user.id:
raise HTTPException(
status_code=403,
detail=f"User not authorized to edit persona with ID {persona.id}",
)


def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
Expand Down Expand Up @@ -543,6 +561,8 @@ def upsert_persona(
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")

check_user_can_edit_persona(user=user, persona=persona)

persona.name = name
persona.description = description
persona.num_chunks = num_chunks
Expand Down
13 changes: 13 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)

"""
Preferences probably should be in a separate table at some point, but for now
putting here for simpicity
"""

# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
chosen_assistants: Mapped[list[int]] = mapped_column(
postgresql.ARRAY(Integer), nullable=True
)

# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"
)
Expand Down
34 changes: 34 additions & 0 deletions backend/danswer/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.chat import check_user_can_edit_persona
from danswer.db.chat import get_prompts_by_ids
from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_document_sets_by_ids
Expand Down Expand Up @@ -97,5 +98,38 @@ def create_update_persona(
return PersonaSnapshot.from_model(persona)


def update_persona_shared_users(
persona_id: int,
user_ids: list[UUID],
user: User | None,
db_session: Session,
) -> None:
"""Simplified version of `create_update_persona` which only touches the
accessibility rather than any of the logic (e.g. prompt, connected data sources,
etc.)."""
persona = fetch_persona_by_id(db_session=db_session, persona_id=persona_id)
if not persona:
raise HTTPException(
status_code=404, detail=f"Persona with ID {persona_id} not found"
)

check_user_can_edit_persona(user=user, persona=persona)

if persona.is_public:
raise HTTPException(status_code=400, detail="Cannot share public persona")

versioned_make_persona_private = fetch_versioned_implementation(
"danswer.db.persona", "make_persona_private"
)

# Privatize Persona
versioned_make_persona_private(
persona_id=persona_id,
user_ids=user_ids,
group_ids=None,
db_session=db_session,
)


def fetch_persona_by_id(db_session: Session, persona_id: int) -> Persona | None:
return db_session.scalar(select(Persona).where(Persona.id == persona_id))
2 changes: 1 addition & 1 deletion backend/danswer/server/auth_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
("/docs/oauth2-redirect", {"GET", "HEAD"}),
("/redoc", {"GET", "HEAD"}),
# should always be callable, will just return 401 if not authenticated
("/manage/me", {"GET"}),
("/me", {"GET"}),
# just returns 200 to validate that the server is up
("/health", {"GET"}),
# just returns auth type, needs to be accessible before the user is logged
Expand Down
22 changes: 22 additions & 0 deletions backend/danswer/server/features/persona/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from pydantic import BaseModel
Expand All @@ -14,6 +16,7 @@
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.persona import create_update_persona
from danswer.db.persona import update_persona_shared_users
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
Expand Down Expand Up @@ -119,6 +122,25 @@ def update_persona(
)


class PersonaShareRequest(BaseModel):
user_ids: list[UUID]


@basic_router.patch("/{persona_id}/share")
def share_persona(
persona_id: int,
persona_share_request: PersonaShareRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_shared_users(
persona_id=persona_id,
user_ids=persona_share_request.user_ids,
user=user,
db_session=db_session,
)


@basic_router.delete("/{persona_id}")
def delete_persona(
persona_id: int,
Expand Down
7 changes: 5 additions & 2 deletions backend/danswer/server/features/persona/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PersonaSnapshot(BaseModel):
prompts: list[PromptSnapshot]
tools: list[ToolSnapshot]
document_sets: list[DocumentSet]
users: list[UUID]
users: list[MinimalUserSnapshot]
groups: list[int]

@classmethod
Expand Down Expand Up @@ -92,7 +92,10 @@ def from_model(
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
users=[user.id for user in persona.users],
users=[
MinimalUserSnapshot(id=user.id, email=user.email)
for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
)

Expand Down
21 changes: 21 additions & 0 deletions backend/danswer/server/manage/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel
from pydantic import root_validator
Expand All @@ -14,6 +15,9 @@
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.features.persona.models import PersonaSnapshot

if TYPE_CHECKING:
from danswer.db.models import User as UserModel


class VersionResponse(BaseModel):
backend_version: str
Expand All @@ -26,13 +30,30 @@ class AuthTypeResponse(BaseModel):
requires_verification: bool


class UserPreferences(BaseModel):
chosen_assistants: list[int] | None


class UserInfo(BaseModel):
id: str
email: str
is_active: bool
is_superuser: bool
is_verified: bool
role: UserRole
preferences: UserPreferences

@classmethod
def from_model(cls, user: "UserModel") -> "UserInfo":
return cls(
id=str(user.id),
email=user.email,
is_active=user.is_active,
is_superuser=user.is_superuser,
is_verified=user.is_verified,
role=user.role,
preferences=(UserPreferences(chosen_assistants=user.chosen_assistants)),
)


class UserByEmail(BaseModel):
Expand Down
Loading

0 comments on commit b690ae0

Please sign in to comment.