Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add knowledge_brain #2988

Merged
merged 25 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ help:
dev:
DOCKER_BUILDKIT=1 docker compose -f docker-compose.dev.yml up --build

dev-build:
DOCKER_BUILDKIT=1 docker compose -f docker-compose.dev.yml build --no-cache
DOCKER_BUILDKIT=1 docker compose -f docker-compose.dev.yml up

## prod: Build and start production environment
.PHONY: prod
prod:
Expand Down
1,098 changes: 603 additions & 495 deletions backend/api/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions backend/api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ langchain-openai = "*"
langchain-cohere = "*"
# This is needed for ITO assistant
llama-parse = "^0.4.9"
pgvector = "^0.3.2"


[tool.poetry.group.dev]
Expand Down
75 changes: 4 additions & 71 deletions backend/api/quivr_api/models/settings.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import os
from typing import Optional
from uuid import UUID

from langchain.embeddings.base import Embeddings
from langchain_community.embeddings.ollama import OllamaEmbeddings
from langchain_community.vectorstores.supabase import SupabaseVectorStore
from langchain_openai import OpenAIEmbeddings
from posthog import Posthog
from pydantic_settings import BaseSettings, SettingsConfigDict
from sqlalchemy import Engine, create_engine
from supabase.client import AsyncClient, Client, create_async_client, create_client

from quivr_api.logger import get_logger
from quivr_api.models.databases.supabase.supabase import SupabaseDB
from sqlalchemy import Engine
from supabase.client import AsyncClient, Client

logger = get_logger(__name__)

Expand Down Expand Up @@ -125,6 +121,7 @@ class BrainSettings(BaseSettings):
langfuse_secret_key: str | None = None
pg_database_url: str
pg_database_async_url: str
embedding_dim: int = int(os.getenv("EMBEDDING_DIM", 1536))
AmineDiro marked this conversation as resolved.
Show resolved Hide resolved


class ResendSettings(BaseSettings):
Expand All @@ -140,67 +137,3 @@ class ResendSettings(BaseSettings):
_embedding_service = None

settings = BrainSettings() # type: ignore


def get_pg_database_engine():
global _db_engine
if _db_engine is None:
logger.info("Creating Postgres DB engine")
_db_engine = create_engine(settings.pg_database_url, pool_pre_ping=True)
return _db_engine


def get_pg_database_async_engine():
global _db_engine
if _db_engine is None:
logger.info("Creating Postgres DB engine")
_db_engine = create_engine(settings.pg_database_async_url, pool_pre_ping=True)
return _db_engine


async def get_supabase_async_client() -> AsyncClient:
global _supabase_async_client
if _supabase_async_client is None:
logger.info("Creating Supabase client")
_supabase_async_client = await create_async_client(
settings.supabase_url, settings.supabase_service_key
)
return _supabase_async_client


def get_supabase_client() -> Client:
global _supabase_client
if _supabase_client is None:
logger.info("Creating Supabase client")
_supabase_client = create_client(
settings.supabase_url, settings.supabase_service_key
)
return _supabase_client


def get_supabase_db() -> SupabaseDB:
global _supabase_db
if _supabase_db is None:
logger.info("Creating Supabase DB")
_supabase_db = SupabaseDB(get_supabase_client())
return _supabase_db


def get_embedding_client() -> Embeddings:
global _embedding_service
if settings.ollama_api_base_url:
embeddings = OllamaEmbeddings(
base_url=settings.ollama_api_base_url,
) # pyright: ignore reportPrivateUsage=none
else:
embeddings = OpenAIEmbeddings() # pyright: ignore reportPrivateUsage=none
return embeddings


def get_documents_vector_store() -> SupabaseVectorStore:
embeddings = get_embedding_client()
supabase_client: Client = get_supabase_client()
documents_vector_store = SupabaseVectorStore(
supabase_client, embeddings, table_name="vectors"
)
return documents_vector_store
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Optional
from uuid import UUID

from quivr_api.models.settings import get_supabase_client
from quivr_api.modules.analytics.entity.analytics import BrainsUsages, Range, Usage
from quivr_api.modules.brain.service.brain_user_service import BrainUserService
from quivr_api.modules.dependencies import get_supabase_client

brain_user_service = BrainUserService()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Optional
from uuid import UUID

from quivr_api.models.settings import get_supabase_client
from quivr_api.modules.api_key.entity.api_key import ApiKey
from quivr_api.modules.api_key.repository.api_key_interface import ApiKeysInterface
from quivr_api.modules.dependencies import get_supabase_client


class ApiKeys(ApiKeysInterface):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from quivr_api.models.settings import get_supabase_client
from quivr_api.modules.assistant.entity.assistant import AssistantEntity
from quivr_api.modules.assistant.repository.assistant_interface import (
AssistantInterface,
)
from quivr_api.modules.dependencies import get_supabase_client


class Assistant(AssistantInterface):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(

def get_chain(self):
list_files_array = (
self.knowledge_qa.knowledge_service.get_all_knowledge_in_brain(
await self.knowledge_qa.knowledge_service.get_all_knowledge_in_brain(
self.brain_id
)
) # pyright: ignore reportPrivateUsage=none
Expand Down
13 changes: 7 additions & 6 deletions backend/api/quivr_api/modules/brain/repository/brains.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from uuid import UUID

from sqlalchemy import text

from quivr_api.logger import get_logger
from quivr_api.models.settings import (
get_embedding_client,
get_pg_database_engine,
get_supabase_client,
)
from quivr_api.modules.brain.dto.inputs import BrainUpdatableProperties
from quivr_api.modules.brain.entity.brain_entity import BrainEntity
from quivr_api.modules.brain.repository.interfaces.brains_interface import (
BrainsInterface,
)
from sqlalchemy import text
from quivr_api.modules.dependencies import (
get_embedding_client,
get_pg_database_engine,
get_supabase_client,
)

logger = get_logger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from uuid import UUID

from quivr_api.logger import get_logger
from quivr_api.models.settings import get_embedding_client, get_supabase_client
from quivr_api.modules.brain.entity.brain_entity import (
BrainUser,
MinimalUserBrainEntity,
)
from quivr_api.modules.brain.repository.interfaces.brains_users_interface import (
BrainsUsersInterface,
)
from quivr_api.modules.dependencies import get_embedding_client, get_supabase_client

logger = get_logger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from uuid import UUID

from sqlalchemy import text

from quivr_api.logger import get_logger
from quivr_api.models.settings import get_pg_database_engine, get_supabase_client
from quivr_api.modules.brain.repository.interfaces.brains_vectors_interface import (
BrainsVectorsInterface,
)
from sqlalchemy import text
from quivr_api.modules.dependencies import get_pg_database_engine, get_supabase_client

logger = get_logger(__name__)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from typing import List

from quivr_api.models.settings import get_supabase_client
from quivr_api.modules.brain.entity.integration_brain import (
IntegrationDescriptionEntity,
IntegrationEntity,
Expand All @@ -10,10 +9,10 @@
IntegrationBrainInterface,
IntegrationDescriptionInterface,
)
from quivr_api.modules.dependencies import get_supabase_client


class Integration(ABC):

@abstractmethod
def load(self):
pass
Expand Down Expand Up @@ -63,7 +62,6 @@ def update_last_synced(self, brain_id, user_id):
return IntegrationEntity(**response.data[0])

def add_integration_brain(self, brain_id, user_id, integration_id, settings):

response = (
self.db.table("integrations_user")
.insert(
Expand Down Expand Up @@ -116,7 +114,6 @@ def get_integration_brain_by_type_integration(


class IntegrationDescription(IntegrationDescriptionInterface):

def __init__(self):
self.db = get_supabase_client()

Expand Down
5 changes: 4 additions & 1 deletion backend/api/quivr_api/modules/brain/service/brain_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
IntegrationBrain,
IntegrationDescription,
)
from quivr_api.modules.dependencies import get_service
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.vectorstore.supabase import CustomSupabaseVectorStore

logger = get_logger(__name__)
knowledge_service = get_service(KnowledgeService)()


class BrainService:
Expand Down Expand Up @@ -201,7 +204,7 @@ def update_brain_last_update_time(self, brain_id: UUID):
self.brain_repository.update_brain_last_update_time(brain_id)

def get_brain_details(
self, brain_id: UUID, user_id: UUID = None
self, brain_id: UUID, user_id: UUID | None = None
chloedia marked this conversation as resolved.
Show resolved Hide resolved
) -> BrainEntity | None:
brain = self.brain_repository.get_brain_details(brain_id)
if brain is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from quivr_api.logger import get_logger
from quivr_api.models.brains_subscription_invitations import BrainSubscription
from quivr_api.models.settings import get_supabase_client
from quivr_api.modules.brain.service.brain_user_service import BrainUserService
from quivr_api.modules.dependencies import get_supabase_client
from quivr_api.modules.user.service.user_service import UserService

logger = get_logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from uuid import UUID

from attr import dataclass

from quivr_api.logger import get_logger
from quivr_api.models.settings import get_embedding_client, get_supabase_client
from quivr_api.modules.dependencies import get_embedding_client, get_supabase_client
from quivr_api.modules.upload.service.generate_file_signed_url import (
generate_file_signed_url,
)
Expand Down
6 changes: 6 additions & 0 deletions backend/api/quivr_api/modules/chat/controller/chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.utils.telemetry import maybe_send_telemetry
from quivr_api.utils.uuid_generator import generate_uuid_from_string
from quivr_api.vector.service.vector_service import VectorService

logger = get_logger(__name__)

Expand All @@ -40,6 +41,7 @@
ChatServiceDep = Annotated[ChatService, Depends(get_service(ChatService))]
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
ModelServiceDep = Annotated[ModelService, Depends(get_service(ModelService))]
VectorServiceDep = Annotated[VectorService, Depends(get_service(VectorService, False))]


def validate_authorization(user_id, brain_id):
Expand Down Expand Up @@ -167,6 +169,7 @@ async def create_question_handler(
chat_service: ChatServiceDep,
knowledge_service: KnowledgeServiceDep,
model_service: ModelServiceDep,
vector_service: VectorServiceDep,
brain_id: Annotated[UUID | None, Query()] = None,
):
models = await model_service.get_models()
Expand All @@ -192,6 +195,7 @@ async def create_question_handler(
prompt_service,
chat_service,
knowledge_service,
vector_service
)
else:
service = ChatLLMService(
Expand Down Expand Up @@ -233,6 +237,7 @@ async def create_stream_question_handler(
current_user: UserIdentityDep,
knowledge_service: KnowledgeServiceDep,
model_service: ModelServiceDep,
vector_service: VectorServiceDep,
brain_id: Annotated[UUID | None, Query()] = None,
) -> StreamingResponse:
logger.info(
Expand Down Expand Up @@ -262,6 +267,7 @@ async def create_stream_question_handler(
prompt_service,
chat_service,
knowledge_service,
vector_service
)
else:
service = ChatLLMService(
Expand Down
11 changes: 6 additions & 5 deletions backend/api/quivr_api/modules/chat/repository/chats.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Sequence
from uuid import UUID

from quivr_api.models.settings import get_supabase_client
from quivr_api.modules.chat.dto.inputs import ChatMessageProperties, QuestionAndAnswer
from quivr_api.modules.chat.entity.chat import Chat, ChatHistory
from quivr_api.modules.dependencies import BaseRepository
from sqlalchemy import exc
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from quivr_api.modules.chat.dto.inputs import ChatMessageProperties, QuestionAndAnswer
from quivr_api.modules.chat.entity.chat import Chat, ChatHistory
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client


class ChatRepository(BaseRepository):
def __init__(self, session: AsyncSession):
Expand Down Expand Up @@ -40,7 +40,8 @@ async def get_chat_by_id(self, chat_id: UUID):

async def get_chat_history(self, chat_id: UUID) -> Sequence[ChatHistory]:
query = (
select(ChatHistory).where(ChatHistory.chat_id == chat_id)
select(ChatHistory)
.where(ChatHistory.chat_id == chat_id)
# TODO: type hints of sqlmodel arent stable for order_by
.order_by(ChatHistory.message_time) # type: ignore
)
Expand Down
Loading
Loading