Skip to content

Commit

Permalink
RAG with database routing - first initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Madhuvod committed Dec 23, 2024
1 parent 0504c5c commit 0d3e3fc
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 0 deletions.
14 changes: 14 additions & 0 deletions rag_tutorials/rag_database_routing/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# RAG Database Router Demo

This demo showcases RAG (Retrieval Augmented Generation) with database routing capabilities. The application allows users to:

1. Upload documents to three different databases:
- Product Information
- Customer Support & FAQ
- Financial Information

2. Query information using natural language, with automatic routing to the most relevant database.

## Setup

1. Create a virtual environment:
252 changes: 252 additions & 0 deletions rag_tutorials/rag_database_routing/rag_database_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import os
from typing import List, Dict, Any, Literal
from dataclasses import dataclass
import streamlit as st
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import tempfile

# Load environment variables
load_dotenv()

# Constants
DatabaseType = Literal["products", "customer_support", "financials"]
PERSIST_DIRECTORY = "db_storage"

@dataclass
class Database:
"""Class to represent a database configuration"""
name: str
description: str
collection_name: str
persist_directory: str

# Database configurations
DATABASES: Dict[DatabaseType, Database] = {
"products": Database(
name="Product Information",
description="Product details, specifications, and features",
collection_name="products_db",
persist_directory=f"{PERSIST_DIRECTORY}/products"
),
"customer_support": Database(
name="Customer Support & FAQ",
description="Customer support information, frequently asked questions, and guides",
collection_name="support_db",
persist_directory=f"{PERSIST_DIRECTORY}/support"
),
"financials": Database(
name="Financial Information",
description="Financial data, revenue, costs, and liabilities",
collection_name="finance_db",
persist_directory=f"{PERSIST_DIRECTORY}/finance"
)
}

# Router prompt template
ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and route them to the most appropriate database.
Available databases:
1. Product Information: Contains product details, specifications, and features
2. Customer Support & FAQ: Contains customer support information, frequently asked questions, and guides
3. Financial Information: Contains financial data, revenue, costs, and liabilities
User question: {question}
Return only one of these exact strings:
- products
- customer_support
- financials
Your response:"""

def init_session_state():
"""Initialize session state variables"""
if 'databases' not in st.session_state:
st.session_state.databases = {}
if 'embeddings' not in st.session_state:
st.session_state.embeddings = OpenAIEmbeddings()
if 'llm' not in st.session_state:
st.session_state.llm = ChatOpenAI(temperature=0)
if 'router_chain' not in st.session_state:
router_prompt = PromptTemplate(
template=ROUTER_TEMPLATE,
input_variables=["question"]
)
st.session_state.router_chain = LLMChain(
llm=st.session_state.llm,
prompt=router_prompt
)

def process_document(file) -> List[Document]:
"""Process uploaded PDF document"""
try:
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
tmp_file.write(file.getvalue())
tmp_path = tmp_file.name

loader = PyPDFLoader(tmp_path)
documents = loader.load()

# Clean up temporary file
os.unlink(tmp_path)

text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
texts = text_splitter.split_documents(documents)

return texts
except Exception as e:
st.error(f"Error processing document: {e}")
return []

def get_or_create_db(db_type: DatabaseType) -> Chroma:
"""Get or create a database for the specified type with proper initialization and error handling"""
try:
if db_type not in st.session_state.databases:
db_config = DATABASES[db_type]

# Ensure directory exists
os.makedirs(db_config.persist_directory, exist_ok=True)

# Initialize Chroma with proper settings
st.session_state.databases[db_type] = Chroma(
persist_directory=db_config.persist_directory,
embedding_function=st.session_state.embeddings,
collection_name=db_config.collection_name,
collection_metadata={
"description": db_config.description,
"database_type": db_type
}
)

# Log successful initialization
st.success(f"Initialized {db_config.name} database")

return st.session_state.databases[db_type]

except Exception as e:
st.error(f"Error initializing {db_type} database: {str(e)}")
raise

def route_query(question: str) -> DatabaseType:
"""Route the question to the appropriate database"""
response = st.session_state.router_chain.invoke({"question": question})
return response["text"].strip().lower()

def query_database(db: Chroma, question: str) -> str:
"""Query the database and return the response"""
docs = db.similarity_search(question, k=3)

context = "\n\n".join([doc.page_content for doc in docs])

prompt = PromptTemplate(
template="""Answer the question based on the following context. If you cannot answer the question based on the context, say "I don't have enough information to answer this question."
Context: {context}
Question: {question}
Answer:""",
input_variables=["context", "question"]
)

chain = LLMChain(llm=st.session_state.llm, prompt=prompt)
response = chain.invoke({"context": context, "question": question})
return response["text"]

def clear_database(db_type: DatabaseType = None):
"""Clear specified database or all databases if none specified"""
try:
if db_type:
if db_type in st.session_state.databases:
db_config = DATABASES[db_type]
# Delete collection
st.session_state.databases[db_type]._collection.delete()
# Remove from session state
del st.session_state.databases[db_type]
# Clean up persist directory
if os.path.exists(db_config.persist_directory):
import shutil
shutil.rmtree(db_config.persist_directory)
st.success(f"Cleared {db_config.name} database")
else:
# Clear all databases
for db_type, db_config in DATABASES.items():
if db_type in st.session_state.databases:
st.session_state.databases[db_type]._collection.delete()
if os.path.exists(db_config.persist_directory):
import shutil
shutil.rmtree(db_config.persist_directory)
st.session_state.databases = {}
st.success("Cleared all databases")
except Exception as e:
st.error(f"Error clearing database(s): {str(e)}")

def main():
st.title("📚 RAG Database Router ")

init_session_state()

# Sidebar for database management
with st.sidebar:
st.header("Database Management")
if st.button("Clear All Databases"):
clear_database()

st.divider()
st.subheader("Clear Individual Databases")
for db_type, db_config in DATABASES.items():
if st.button(f"Clear {db_config.name}"):
clear_database(db_type)

# Document upload section
st.header("Document Upload")
tabs = st.tabs([db.name for db in DATABASES.values()])

for (db_type, db_config), tab in zip(DATABASES.items(), tabs):
with tab:
st.write(db_config.description)
uploaded_file = st.file_uploader(
"Upload PDF document",
type="pdf",
key=f"upload_{db_type}"
)

if uploaded_file:
with st.spinner('Processing document...'):
texts = process_document(uploaded_file)
if texts:
db = get_or_create_db(db_type)
db.add_documents(texts)
st.success("Document processed and added to the database!")

# Query section
st.header("Ask Questions")
question = st.text_input("Enter your question:")

if question:
with st.spinner('Finding answer...'):
# Route the question
db_type = route_query(question)
db = get_or_create_db(db_type)

# Display routing information
st.info(f"Routing question to: {DATABASES[db_type].name}")

# Get and display answer
answer = query_database(db, question)
st.write("### Answer")
st.write(answer)

if __name__ == "__main__":
main()
9 changes: 9 additions & 0 deletions rag_tutorials/rag_database_routing/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
langchain>=0.1.0
langchain-community>=0.0.10
langchain-core>=0.1.10
chromadb>=0.4.22
streamlit>=1.29.0
python-dotenv>=1.0.0
pypdf>=4.0.0
sentence-transformers>=2.2.2
openai>=1.6.1

0 comments on commit 0d3e3fc

Please sign in to comment.