Skip to content

Commit

Permalink
final changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Madhuvod committed Dec 25, 2024
1 parent 7035e9e commit d0c0798
Showing 1 changed file with 112 additions and 29 deletions.
141 changes: 112 additions & 29 deletions rag_tutorials/rag_database_routing/rag_database_routing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
from typing import List, Dict, Literal
from typing import List, Dict, Any, Literal
from dataclasses import dataclass
import streamlit as st
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
import tempfile
from phi.agent import Agent
from phi.model.openai import OpenAIChat
Expand All @@ -17,8 +18,10 @@
from langgraph.prebuilt import create_react_agent
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.language_models import BaseLanguageModel
from langchain.prompts import ChatPromptTemplate

def init_session_state():
"""Initialize session state variables"""
if 'openai_api_key' not in st.session_state:
st.session_state.openai_api_key = ""
if 'embeddings' not in st.session_state:
Expand All @@ -40,6 +43,7 @@ class CollectionConfig:
collection_name: str
persist_directory: str

# Collection configurations
COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
"products": CollectionConfig(
name="Product Information",
Expand All @@ -62,50 +66,77 @@ class CollectionConfig:
}

def initialize_models():
"""Initialize OpenAI models with API key"""
if st.session_state.openai_api_key:
os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key
st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
st.session_state.llm = ChatOpenAI(temperature=0)

for config in COLLECTIONS.values():
os.makedirs(config.persist_directory, exist_ok=True)
# Ensure directories exist
for collection_config in COLLECTIONS.values():
os.makedirs(collection_config.persist_directory, exist_ok=True)

# Initialize Chroma collections
st.session_state.databases = {
db_type: Chroma(
collection_name=config.collection_name,
"products": Chroma(
collection_name=COLLECTIONS["products"].collection_name,
embedding_function=st.session_state.embeddings,
persist_directory=config.persist_directory
) for db_type, config in COLLECTIONS.items()
persist_directory=COLLECTIONS["products"].persist_directory
),
"support": Chroma(
collection_name=COLLECTIONS["support"].collection_name,
embedding_function=st.session_state.embeddings,
persist_directory=COLLECTIONS["support"].persist_directory
),
"finance": Chroma(
collection_name=COLLECTIONS["finance"].collection_name,
embedding_function=st.session_state.embeddings,
persist_directory=COLLECTIONS["finance"].persist_directory
)
}
return True
return False

def process_document(file) -> List[Document]:
"""Process uploaded PDF document"""
try:
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
tmp_file.write(file.getvalue())
tmp_path = tmp_file.name

loader = PyPDFLoader(tmp_path)
documents = loader.load()

# Clean up temporary file
os.unlink(tmp_path)

text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
return text_splitter.split_documents(documents)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
texts = text_splitter.split_documents(documents)

return texts
except Exception as e:
st.error(f"Error processing document: {e}")
return []

def create_routing_agent() -> Agent:
"""Creates a routing agent using phidata framework"""
return Agent(
model=OpenAIChat(id="gpt-4o", api_key=st.session_state.openai_api_key),
model=OpenAIChat(
id="gpt-4o",
api_key=st.session_state.openai_api_key
),
tools=[],
description="You are a query routing expert. Your only job is to analyze questions and determine which database they should be routed to.",
description="""You are a query routing expert. Your only job is to analyze questions and determine which database they should be routed to.
You must respond with exactly one of these three options: 'products', 'support', or 'finance'. The user's question is: {question}""",
instructions=[
"1. For questions about products, return 'products'",
"2. For questions about support, return 'support'",
"3. For questions about finance, return 'finance'",
"4. Return ONLY the database name"
"Follow these rules strictly:",
"1. For questions about products, features, specifications, or item details, or product manuals → return 'products'",
"2. For questions about help, guidance, troubleshooting, or customer service, FAQ, or guides → return 'support'",
"3. For questions about costs, revenue, pricing, or financial data, or financial reports and investments → return 'finance'",
"4. Return ONLY the database name, no other text or explanation"
],
markdown=False,
show_tool_calls=False
Expand All @@ -115,42 +146,72 @@ def route_query(question: str) -> DatabaseType:
try:
routing_agent = create_routing_agent()
response = routing_agent.run(question)
db_type = response.content.strip().lower().translate(str.maketrans('', '', '`\'"'))

db_type = (response.content
.strip()
.lower()
.translate(str.maketrans('', '', '`\'"'))) # More elegant string cleaning

# Validate database type
if db_type not in COLLECTIONS:
st.warning(f"Invalid database type: {db_type}, defaulting to products")
return "products"

st.info(f"Routing question to {db_type} database")
return db_type

except Exception as e:
st.error(f"Routing error: {str(e)}")
return "products"

def create_fallback_agent(chat_model: BaseLanguageModel):
"""Create a LangGraph agent for web research."""

def web_research(query: str) -> str:
"""Web search with result formatting."""
try:
search = DuckDuckGoSearchRun(num_results=5)
return search.run(query)
results = search.run(query)
return results
except Exception as e:
return f"Search failed: {str(e)}. Providing answer based on general knowledge."

tools = [web_research]
return create_react_agent(model=chat_model, tools=tools, debug=False)

agent = create_react_agent(model=chat_model,
tools=tools,
debug=False)

return agent

def query_database(db: Chroma, question: str) -> tuple[str, list]:
try:
retriever = db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"k": 4, "score_threshold": 0.4})
retriever = db.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": 4, "score_threshold": 0.3}
)

relevant_docs = retriever.get_relevant_documents(question)

if relevant_docs:
retrieval_qa_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
# Use simpler chain creation with hub prompt
retrieval_qa_prompt = ChatPromptTemplate.from_messages([
("system", """You are a helpful AI assistant that answers questions based on provided context.
Always be direct and concise in your responses.
If the context doesn't contain enough information to fully answer the question, acknowledge this limitation.
Base your answers strictly on the provided context and avoid making assumptions."""),
("human", "Here is the context:\n{context}"),
("human", "Question: {input}"),
("assistant", "I'll help answer your question based on the context provided."),
("human", "Please provide your answer:"),
])
combine_docs_chain = create_stuff_documents_chain(st.session_state.llm, retrieval_qa_prompt)
retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)

response = retrieval_chain.invoke({"input": question})
return response['answer'], relevant_docs

return _handle_web_fallback(question)

except Exception as e:
st.error(f"Error: {str(e)}")
return "I encountered an error. Please try rephrasing your question.", []
Expand All @@ -161,7 +222,9 @@ def _handle_web_fallback(question: str) -> tuple[str, list]:

with st.spinner('Researching...'):
agent_input = {
"messages": [HumanMessage(content=f"Research and provide a detailed answer for: '{question}'")],
"messages": [
HumanMessage(content=f"Research and provide a detailed answer for: '{question}'")
],
"is_last_step": False
}

Expand All @@ -170,17 +233,26 @@ def _handle_web_fallback(question: str) -> tuple[str, list]:
if isinstance(response, dict) and "messages" in response:
answer = response["messages"][-1].content
return f"Web Search Result:\n{answer}", []

except Exception:
# Fallback to general LLM response
fallback_response = st.session_state.llm.invoke(question).content
return f"Web search unavailable. General response: {fallback_response}", []

def main():
"""Main application function."""
st.set_page_config(page_title="RAG Agent with Database Routing", page_icon="📚")
st.title("📚 RAG Agent with Database Routing")

# Sidebar for API key and database management
with st.sidebar:
st.header("Configuration")
api_key = st.text_input("Enter OpenAI API Key:", type="password", value=st.session_state.openai_api_key, key="api_key_input")
api_key = st.text_input(
"Enter OpenAI API Key:",
type="password",
value=st.session_state.openai_api_key,
key="api_key_input"
)

if api_key:
st.session_state.openai_api_key = api_key
Expand All @@ -194,15 +266,20 @@ def main():
st.stop()

st.markdown("---")

st.header("Document Upload")
st.info("Upload documents to populate the databases. Each tab corresponds to a different database.")
tabs = st.tabs([config.name for config in COLLECTIONS.values()])
tabs = st.tabs([collection_config.name for collection_config in COLLECTIONS.values()])

for (collection_type, config), tab in zip(COLLECTIONS.items(), tabs):
for (collection_type, collection_config), tab in zip(COLLECTIONS.items(), tabs):
with tab:
st.write(config.description)
uploaded_files = st.file_uploader(f"Upload PDF documents to {config.name}", type="pdf", key=f"upload_{collection_type}", accept_multiple_files=True)
st.write(collection_config.description)
uploaded_files = st.file_uploader(
f"Upload PDF documents to {collection_config.name}",
type="pdf",
key=f"upload_{collection_type}",
accept_multiple_files=True
)

if uploaded_files:
with st.spinner('Processing documents...'):
Expand All @@ -216,15 +293,21 @@ def main():
db.add_documents(all_texts)
st.success("Documents processed and added to the database!")

# Query section
st.header("Ask Questions")
st.info("Enter your question below to find answers from the relevant database.")
question = st.text_input("Enter your question:")

if question:
with st.spinner('Finding answer...'):
# Route the question
collection_type = route_query(question)
db = st.session_state.databases[collection_type]

# Display routing information
st.info(f"Routing question to: {COLLECTIONS[collection_type].name}")

# Get and display answer
answer, relevant_docs = query_database(db, question)
st.write("### Answer")
st.write(answer)
Expand Down

0 comments on commit d0c0798

Please sign in to comment.