Skip to content

Commit

Permalink
add: LangfuseService class
Browse files Browse the repository at this point in the history
  • Loading branch information
chloedia committed Dec 23, 2024
1 parent dbc2146 commit 2d7be22
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
5 changes: 3 additions & 2 deletions core/quivr_core/rag/quivr_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from langfuse.callback import CallbackHandler

from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.chat import ChatHistory
Expand All @@ -25,6 +24,7 @@
)
from quivr_core.rag.prompts import custom_prompts
from quivr_core.rag.utils import (
LangfuseService,
combine_documents,
format_file_list,
get_chunk_metadata,
Expand All @@ -33,7 +33,8 @@
)

logger = logging.getLogger("quivr_core")
langfuse_handler = CallbackHandler()
langfuse_service = LangfuseService()
langfuse_handler = langfuse_service.get_handler()


class IdempotentCompressor(BaseDocumentCompressor):
Expand Down
17 changes: 8 additions & 9 deletions core/quivr_core/rag/quivr_rag_langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
from langgraph.types import Send
from pydantic import BaseModel, Field

from langfuse.callback import CallbackHandler

from quivr_core.llm import LLMEndpoint
from quivr_core.llm_tools.llm_tools import LLMToolFactory
from quivr_core.rag.entities.chat import ChatHistory
Expand All @@ -41,6 +39,7 @@
)
from quivr_core.rag.prompts import custom_prompts
from quivr_core.rag.utils import (
LangfuseService,
collect_tools,
combine_documents,
format_file_list,
Expand All @@ -50,8 +49,8 @@

logger = logging.getLogger("quivr_core")

# Initialize Langfuse CallbackHandler for Langchain (tracing)
langfuse_handler = CallbackHandler()
langfuse_service = LangfuseService()
langfuse_handler = langfuse_service.get_handler()


class SplittedInput(BaseModel):
Expand Down Expand Up @@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState:
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

# Replace each question with its condensed version
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
tasks.set_definition(task_id, response.content)

return {**state, "tasks": tasks}
Expand Down Expand Up @@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState):
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
tasks.set_completion(task_id, response.is_task_completable)
if not response.is_task_completable and response.tool:
tasks.set_tool(task_id, response.tool)
Expand Down Expand Up @@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState:
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = tool_wrapper.format_output(response)
_docs = self.filter_chunks_by_relevance(_docs)
tasks.set_docs(task_id, _docs)
Expand Down Expand Up @@ -652,7 +651,7 @@ async def retrieve(self, state: AgentState) -> AgentState:
task_ids = [task[1] for task in async_jobs] if async_jobs else []

# Process responses and associate docs with tasks
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = self.filter_chunks_by_relevance(response)
tasks.set_docs(task_id, _docs) # Associate docs with the specific task

Expand Down Expand Up @@ -715,7 +714,7 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

_n = []
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = self.filter_chunks_by_relevance(response)
_n.append(len(_docs))
tasks.set_docs(task_id, _docs)
Expand Down
9 changes: 9 additions & 0 deletions core/quivr_core/rag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document
from langfuse.callback import CallbackHandler

from quivr_core.rag.entities.config import WorkflowConfig
from quivr_core.rag.entities.models import (
Expand Down Expand Up @@ -195,3 +196,11 @@ def collect_tools(workflow_config: WorkflowConfig):
activated_tools += f"Tool {i+1} description: {tool.description}\n\n"

return validated_tools, activated_tools


class LangfuseService:
def __init__(self):
self.langfuse_handler = CallbackHandler()

def get_handler(self):
return self.langfuse_handler

0 comments on commit 2d7be22

Please sign in to comment.