diff --git a/backend/main.py b/backend/main.py index adba9583b8..910a059d8a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -3,6 +3,7 @@ import firebase_admin from fastapi import FastAPI +import utils.retrieval.graph as graph from modal import Image, App, asgi_app, Secret, Cron from routers import workflow, chat, firmware, plugins, memories, transcribe_v2, notifications, \ diff --git a/backend/requirements.txt b/backend/requirements.txt index dcf66e04c6..54f1ef2ab6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,7 +1,7 @@ aenum==3.1.15 aiofiles==24.1.0 aiohappyeyeballs==2.3.4 -aiohttp==3.10.1 +aiohttp==3.9.5 aiosignal==1.3.1 aiostream==0.5.2 alembic==1.13.2 @@ -61,6 +61,7 @@ google-cloud-storage==2.18.0 google-crc32c==1.5.0 google-resumable-media==2.7.1 googleapis-common-protos==1.63.2 +graphviz==0.20.3 groq==0.9.0 grpcio==1.65.4 grpcio-status==1.62.3 @@ -82,6 +83,7 @@ idna==3.7 ipython==8.26.0 jedi==0.19.1 Jinja2==3.1.4 +jiter==0.6.1 jiwer==3.0.4 joblib==1.4.2 jsonpatch==1.33 @@ -90,14 +92,17 @@ jsonschema==4.23.0 jsonschema-specifications==2023.12.1 julius==0.2.7 kiwisolver==1.4.5 -langchain==0.2.12 -langchain-community==0.2.11 -langchain-core==0.2.28 +langchain==0.3.4 +langchain-community==0.3.3 +langchain-core==0.3.12 langchain-groq==0.1.9 -langchain-openai==0.1.20 -langchain-pinecone==0.1.3 -langchain-text-splitters==0.2.2 -langsmith==0.1.96 +langchain-openai==0.2.3 +langchain-pinecone==0.2.0 +langchain-text-splitters==0.3.0 +langgraph==0.2.39 +langgraph-checkpoint==2.0.1 +langgraph-sdk==0.1.33 +langsmith==0.1.137 lazy_loader==0.4 librosa==0.10.2.post1 lightning==2.4.0 @@ -116,7 +121,7 @@ more-itertools==10.5.0 mplcursors==0.5.3 mpld3==0.5.10 mpmath==1.3.0 -msgpack==1.0.8 +msgpack==1.1.0 multidict==6.0.5 mypy-extensions==1.0.0 narwhals==1.5.2 @@ -127,7 +132,7 @@ numba==0.60.0 numpy==1.26.4 omegaconf==2.3.0 onnxruntime==1.19.0 -openai==1.39.0 +openai==1.52.2 optuna==3.6.1 opuslib==3.0.1 orjson==3.10.6 @@ -161,13 +166,15 @@ 
pyasn1==0.6.0 pyasn1_modules==0.4.0 pycparser==2.22 pydantic==2.8.2 +pydantic-settings==2.6.0 pydantic_core==2.20.1 pydeck==0.9.1 pydub==0.25.1 Pygments==2.18.0 +pygraphviz==1.14 PyJWT==2.9.0 pynndescent==0.5.13 -PyOgg@ git+https://github.com/TeamPyOgg/PyOgg@6871a4f234e8a3a346c4874a12509bfa02c4c63a +PyOgg @ git+https://github.com/TeamPyOgg/PyOgg@6871a4f234e8a3a346c4874a12509bfa02c4c63a pyparsing==3.1.2 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 @@ -182,6 +189,7 @@ redis==5.0.8 referencing==0.35.1 regex==2024.7.24 requests==2.32.3 +requests-toolbelt==1.0.0 rich==13.7.1 rpds-py==0.20.0 rsa==4.9 @@ -241,4 +249,4 @@ watchfiles==0.22.0 wcwidth==0.2.13 webrtcvad==2.0.10 websockets==12.0 -yarl==1.9.4 \ No newline at end of file +yarl==1.9.4 diff --git a/backend/scripts/users/retrieval.py b/backend/scripts/users/retrieval.py index 5dd895ad68..b527b455df 100644 --- a/backend/scripts/users/retrieval.py +++ b/backend/scripts/users/retrieval.py @@ -44,7 +44,8 @@ class ExtractedInformation(BaseModel): ) -def migrate_memory_vector_metadata(memory_id: str, created_at: datetime, transcript_segment: List[dict]) -> ExtractedInformation: +def migrate_memory_vector_metadata(memory_id: str, created_at: datetime, + transcript_segment: List[dict]) -> ExtractedInformation: transcript = '' for segment in transcript_segment: transcript += f'{segment["text"].strip()}\n\n' @@ -82,6 +83,8 @@ def migrate_memory_vector_metadata(memory_id: str, created_at: datetime, transcr if __name__ == '__main__': + # TODO: finish migration script + # TODO: include process_memory to process this too uids = get_users_uid() for uid in ['TtCJi59JTVXHmyUC6vUQ1d9U6cK2']: memories = memories_db.get_memories(uid, limit=1000) diff --git a/backend/utils/retrieval/graph.py b/backend/utils/retrieval/graph.py new file mode 100644 index 0000000000..d476bdc1c1 --- /dev/null +++ b/backend/utils/retrieval/graph.py @@ -0,0 +1,87 @@ +from typing import Literal + +from langchain_openai import ChatOpenAI +from 
langgraph.checkpoint.memory import MemorySaver +from langgraph.constants import END +from langgraph.graph import START, StateGraph, MessagesState + +model = ChatOpenAI(model='gpt-4o-mini') + + +def determine_conversation_type(s: MessagesState) -> Literal["no_context_conversation", "context_dependent_conversation"]: + # call requires context + # if requires context, spawn 2 parallel graphs edges? + return 'no_context_conversation' + + +def no_context_conversation(state: MessagesState): + # continue the conversation + return END + + +def context_dependent_conversation(state: MessagesState): + pass + + +# TODO: include a question extractor? node? + + +def retrieve_topics_filters(state: MessagesState): + # retrieve all available entities, names, topics, etc, and ask it to filter based on the question. + return 'query_vectors' + + +def retrieve_date_filters(state: MessagesState): + # extract date filters, and send them to qa_handler node + return 'query_vectors' + + +def query_vectors(state: MessagesState): + # receives both filters, and finds vectors + rerank them + # TODO: maybe didn't find anything, tries RAG, or goes to simple conversation? + pass + + +def qa_handler(state: MessagesState): + # takes vectors found, retrieves memories, and does QA on them + return END + + +workflow = StateGraph(MessagesState) # custom state? 
+ + + + +workflow.add_edge(START, "determine_conversation_type") +workflow.add_node('determine_conversation_type', determine_conversation_type) + +workflow.add_conditional_edges( + "determine_conversation_type", + determine_conversation_type, +) + +workflow.add_node("no_context_conversation", no_context_conversation) +workflow.add_node("context_dependent_conversation", context_dependent_conversation) + +workflow.add_edge("context_dependent_conversation", "retrieve_topics_filters") +workflow.add_edge("context_dependent_conversation", "retrieve_date_filters") + +workflow.add_node("retrieve_topics_filters", retrieve_topics_filters) +workflow.add_node("retrieve_date_filters", retrieve_date_filters) + +workflow.add_edge('retrieve_topics_filters', 'query_vectors') +workflow.add_edge('retrieve_date_filters', 'query_vectors') + +workflow.add_node('query_vectors', query_vectors) + +workflow.add_edge('query_vectors', 'qa_handler') + +workflow.add_node('qa_handler', qa_handler) + +workflow.add_edge('qa_handler', END) + +checkpointer = MemorySaver() +graph = workflow.compile(checkpointer=checkpointer) + +if __name__ == '__main__': + graph.get_graph().draw_png('workflow.png')