Skip to content

Commit

Permalink
initial langgraph graph setup
Browse files Browse the repository at this point in the history
  • Loading branch information
josancamon19 committed Oct 24, 2024
1 parent 790c22f commit f713d4b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 13 deletions.
1 change: 1 addition & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import firebase_admin
from fastapi import FastAPI
import utils.retrieval.graph as graph

from modal import Image, App, asgi_app, Secret, Cron
from routers import workflow, chat, firmware, plugins, memories, transcribe_v2, notifications, \
Expand Down
32 changes: 20 additions & 12 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
aenum==3.1.15
aiofiles==24.1.0
aiohappyeyeballs==2.3.4
aiohttp==3.10.1
aiohttp==3.9.5
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.2
Expand Down Expand Up @@ -61,6 +61,7 @@ google-cloud-storage==2.18.0
google-crc32c==1.5.0
google-resumable-media==2.7.1
googleapis-common-protos==1.63.2
graphviz==0.20.3
groq==0.9.0
grpcio==1.65.4
grpcio-status==1.62.3
Expand All @@ -82,6 +83,7 @@ idna==3.7
ipython==8.26.0
jedi==0.19.1
Jinja2==3.1.4
jiter==0.6.1
jiwer==3.0.4
joblib==1.4.2
jsonpatch==1.33
Expand All @@ -90,14 +92,17 @@ jsonschema==4.23.0
jsonschema-specifications==2023.12.1
julius==0.2.7
kiwisolver==1.4.5
langchain==0.2.12
langchain-community==0.2.11
langchain-core==0.2.28
langchain==0.3.4
langchain-community==0.3.3
langchain-core==0.3.12
langchain-groq==0.1.9
langchain-openai==0.1.20
langchain-pinecone==0.1.3
langchain-text-splitters==0.2.2
langsmith==0.1.96
langchain-openai==0.2.3
langchain-pinecone==0.2.0
langchain-text-splitters==0.3.0
langgraph==0.2.39
langgraph-checkpoint==2.0.1
langgraph-sdk==0.1.33
langsmith==0.1.137
lazy_loader==0.4
librosa==0.10.2.post1
lightning==2.4.0
Expand All @@ -116,7 +121,7 @@ more-itertools==10.5.0
mplcursors==0.5.3
mpld3==0.5.10
mpmath==1.3.0
msgpack==1.0.8
msgpack==1.1.0
multidict==6.0.5
mypy-extensions==1.0.0
narwhals==1.5.2
Expand All @@ -127,7 +132,7 @@ numba==0.60.0
numpy==1.26.4
omegaconf==2.3.0
onnxruntime==1.19.0
openai==1.39.0
openai==1.52.2
optuna==3.6.1
opuslib==3.0.1
orjson==3.10.6
Expand Down Expand Up @@ -161,13 +166,15 @@ pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==2.8.2
pydantic-settings==2.6.0
pydantic_core==2.20.1
pydeck==0.9.1
pydub==0.25.1
Pygments==2.18.0
pygraphviz==1.14
PyJWT==2.9.0
pynndescent==0.5.13
PyOgg@ git+https://github.com/TeamPyOgg/PyOgg@6871a4f234e8a3a346c4874a12509bfa02c4c63a
PyOgg @ git+https://github.com/TeamPyOgg/PyOgg@6871a4f234e8a3a346c4874a12509bfa02c4c63a
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
Expand All @@ -182,6 +189,7 @@ redis==5.0.8
referencing==0.35.1
regex==2024.7.24
requests==2.32.3
requests-toolbelt==1.0.0
rich==13.7.1
rpds-py==0.20.0
rsa==4.9
Expand Down Expand Up @@ -241,4 +249,4 @@ watchfiles==0.22.0
wcwidth==0.2.13
webrtcvad==2.0.10
websockets==12.0
yarl==1.9.4
yarl==1.9.4
5 changes: 4 additions & 1 deletion backend/scripts/users/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class ExtractedInformation(BaseModel):
)


def migrate_memory_vector_metadata(memory_id: str, created_at: datetime, transcript_segment: List[dict]) -> ExtractedInformation:
def migrate_memory_vector_metadata(memory_id: str, created_at: datetime,
transcript_segment: List[dict]) -> ExtractedInformation:
transcript = ''
for segment in transcript_segment:
transcript += f'{segment["text"].strip()}\n\n'
Expand Down Expand Up @@ -82,6 +83,8 @@ def migrate_memory_vector_metadata(memory_id: str, created_at: datetime, transcr


if __name__ == '__main__':
# TODO: finish migration script
    # TODO: include process_memory to process this too
uids = get_users_uid()
for uid in ['TtCJi59JTVXHmyUC6vUQ1d9U6cK2']:
memories = memories_db.get_memories(uid, limit=1000)
Expand Down
87 changes: 87 additions & 0 deletions backend/utils/retrieval/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""LangGraph retrieval workflow.

Routes a chat turn either to a plain (no-context) conversation or to a
retrieval pipeline: topic/date filter extraction -> vector query -> QA.
All node functions below are stubs at this stage.
"""
from typing import Literal

from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.constants import END
from langgraph.graph import START, StateGraph, MessagesState

# Shared chat model for the graph's nodes.
model = ChatOpenAI(model='gpt-4o-mini')


def determine_conversation_type(s: MessagesState) -> Literal["no_context_conversation", "context_dependent_conversation"]:
    """Decide which branch of the graph should handle the current turn.

    Placeholder router: always sends the turn down the no-context branch.
    TODO: ask the model whether the question actually requires retrieved
    context; when it does, the two filter-extraction edges run in parallel.
    """
    return "no_context_conversation"


def no_context_conversation(state: MessagesState):
    # continue the conversation (no retrieval needed for this turn)
    # NOTE(review): LangGraph node functions normally return a state update
    # (dict or None), not the END sentinel — returning END here looks like a
    # placeholder; confirm before wiring real logic.
    return END


def context_dependent_conversation(state: MessagesState):
    """Handle a turn that needs retrieved context (stub, no-op for now)."""


# TODO: include a question extractor? node?


def retrieve_topics_filters(state: MessagesState):
    """Narrow the known entities/names/topics down to the user's question.

    Stub: intended to load every available entity, name and topic and ask
    the model to keep only the ones relevant to the question.
    """
    next_node = 'query_vectors'
    return next_node


def retrieve_date_filters(state: MessagesState):
    """Extract date-range filters from the question (stub).

    The extracted filters are consumed downstream by the vector query node.
    """
    next_node = 'query_vectors'
    return next_node


def query_vectors(state: MessagesState):
    """Combine both filter sets, fetch matching vectors and rerank them (stub).

    TODO: if nothing matches, either fall back to RAG or route back to the
    simple conversation branch.
    """


def qa_handler(state: MessagesState):
    # takes vectors found, retrieves memories, and does QA on them
    # NOTE(review): node functions normally return a state update, not the
    # END sentinel — this reads as a placeholder; confirm before implementing.
    return END


# Graph wiring: START -> router -> (no-context | context-dependent) ->
# parallel filter extraction -> vector query -> QA -> END.
workflow = StateGraph(MessagesState)  # custom state?

# The entry node is a passthrough; routing happens in the conditional edges
# below. Registering `determine_conversation_type` itself as the node body
# would make the node return a route string, which is not a valid
# MessagesState update and fails at runtime.
workflow.add_node('determine_conversation_type', lambda state: state)
workflow.add_edge(START, "determine_conversation_type")

# Route on context need; target node names come from the router's
# Literal return annotation.
workflow.add_conditional_edges(
    "determine_conversation_type",
    determine_conversation_type,
)

workflow.add_node("no_context_conversation", no_context_conversation)
workflow.add_node("context_dependent_conversation", context_dependent_conversation)

# Fan out: both filter extractions run in parallel.
workflow.add_edge("context_dependent_conversation", "retrieve_topics_filters")
workflow.add_edge("context_dependent_conversation", "retrieve_date_filters")

workflow.add_node("retrieve_topics_filters", retrieve_topics_filters)
workflow.add_node("retrieve_date_filters", retrieve_date_filters)

# Fan in: both filter sets feed the vector query.
workflow.add_edge('retrieve_topics_filters', 'query_vectors')
workflow.add_edge('retrieve_date_filters', 'query_vectors')

workflow.add_node('query_vectors', query_vectors)
workflow.add_edge('query_vectors', 'qa_handler')
workflow.add_node('qa_handler', qa_handler)
workflow.add_edge('qa_handler', END)

# In-memory checkpointer: per-thread state only, lost on process restart.
checkpointer = MemorySaver()
graph = workflow.compile(checkpointer=checkpointer)

if __name__ == '__main__':
    # Render the compiled graph for inspection (draw_png needs pygraphviz,
    # which this commit adds to requirements).
    output_path = 'workflow.png'
    graph.get_graph().draw_png(output_path)

0 comments on commit f713d4b

Please sign in to comment.