RAG with database routing - first initialization
Showing 3 changed files with 275 additions and 0 deletions.
# RAG Database Router Demo

This demo showcases RAG (Retrieval Augmented Generation) with database routing capabilities. The application allows users to:

1. Upload documents to three different databases:
   - Product Information
   - Customer Support & FAQ
   - Financial Information

2. Query information using natural language, with automatic routing to the most relevant database.

## Setup

1. Create a virtual environment:
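(A typical continuation of this step, assumed here rather than quoted from the README, would be `python -m venv venv`, activating the environment, and then installing the pinned dependencies listed at the end of this commit.)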
rag_tutorials/rag_database_routing/rag_database_routing.py (252 additions, 0 deletions)
import os
import shutil
import tempfile
from typing import List, Dict, Any, Literal
from dataclasses import dataclass
import streamlit as st
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

# Load environment variables (OPENAI_API_KEY is expected, e.g. in a .env file)
load_dotenv()

# Constants
DatabaseType = Literal["products", "customer_support", "financials"]
PERSIST_DIRECTORY = "db_storage"


@dataclass
class Database:
    """Class to represent a database configuration"""
    name: str
    description: str
    collection_name: str
    persist_directory: str


# Database configurations
DATABASES: Dict[DatabaseType, Database] = {
    "products": Database(
        name="Product Information",
        description="Product details, specifications, and features",
        collection_name="products_db",
        persist_directory=f"{PERSIST_DIRECTORY}/products"
    ),
    "customer_support": Database(
        name="Customer Support & FAQ",
        description="Customer support information, frequently asked questions, and guides",
        collection_name="support_db",
        persist_directory=f"{PERSIST_DIRECTORY}/support"
    ),
    "financials": Database(
        name="Financial Information",
        description="Financial data, revenue, costs, and liabilities",
        collection_name="finance_db",
        persist_directory=f"{PERSIST_DIRECTORY}/finance"
    )
}

# Router prompt template
ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and route them to the most appropriate database.
Available databases:
1. Product Information: Contains product details, specifications, and features
2. Customer Support & FAQ: Contains customer support information, frequently asked questions, and guides
3. Financial Information: Contains financial data, revenue, costs, and liabilities
User question: {question}
Return only one of these exact strings:
- products
- customer_support
- financials
Your response:"""
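# The three strings above double as DATABASES keys: route_query() maps the model's
# reply onto one of them, and main() then queries the matching Chroma collection.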

def init_session_state():
    """Initialize session state variables"""
    if 'databases' not in st.session_state:
        st.session_state.databases = {}
    if 'embeddings' not in st.session_state:
        st.session_state.embeddings = OpenAIEmbeddings()
    if 'llm' not in st.session_state:
        st.session_state.llm = ChatOpenAI(temperature=0)
    if 'router_chain' not in st.session_state:
        router_prompt = PromptTemplate(
            template=ROUTER_TEMPLATE,
            input_variables=["question"]
        )
        st.session_state.router_chain = LLMChain(
            llm=st.session_state.llm,
            prompt=router_prompt
        )

def process_document(file) -> List[Document]:
    """Process uploaded PDF document"""
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            tmp_file.write(file.getvalue())
            tmp_path = tmp_file.name

        loader = PyPDFLoader(tmp_path)
        documents = loader.load()

        # Clean up temporary file
        os.unlink(tmp_path)

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        texts = text_splitter.split_documents(documents)

        return texts
    except Exception as e:
        st.error(f"Error processing document: {e}")
        return []

def get_or_create_db(db_type: DatabaseType) -> Chroma:
    """Get or create a database for the specified type with proper initialization and error handling"""
    try:
        if db_type not in st.session_state.databases:
            db_config = DATABASES[db_type]

            # Ensure directory exists
            os.makedirs(db_config.persist_directory, exist_ok=True)

            # Initialize Chroma with proper settings
            st.session_state.databases[db_type] = Chroma(
                persist_directory=db_config.persist_directory,
                embedding_function=st.session_state.embeddings,
                collection_name=db_config.collection_name,
                collection_metadata={
                    "description": db_config.description,
                    "database_type": db_type
                }
            )

            # Log successful initialization
            st.success(f"Initialized {db_config.name} database")

        return st.session_state.databases[db_type]

    except Exception as e:
        st.error(f"Error initializing {db_type} database: {str(e)}")
        raise

def route_query(question: str) -> DatabaseType:
    """Route the question to the appropriate database"""
    response = st.session_state.router_chain.invoke({"question": question})
    db_type = response["text"].strip().lower()
    # Fall back to "products" if the model returns something other than the expected keys
    return db_type if db_type in DATABASES else "products"

def query_database(db: Chroma, question: str) -> str:
    """Query the database and return the response"""
    docs = db.similarity_search(question, k=3)

    context = "\n\n".join([doc.page_content for doc in docs])

    prompt = PromptTemplate(
        template="""Answer the question based on the following context. If you cannot answer the question based on the context, say "I don't have enough information to answer this question."
Context: {context}
Question: {question}
Answer:""",
        input_variables=["context", "question"]
    )

    chain = LLMChain(llm=st.session_state.llm, prompt=prompt)
    response = chain.invoke({"context": context, "question": question})
    return response["text"]

def clear_database(db_type: DatabaseType = None):
    """Clear specified database or all databases if none specified"""
    try:
        if db_type:
            if db_type in st.session_state.databases:
                db_config = DATABASES[db_type]
                # Delete collection
                st.session_state.databases[db_type]._collection.delete()
                # Remove from session state
                del st.session_state.databases[db_type]
                # Clean up persist directory
                if os.path.exists(db_config.persist_directory):
                    shutil.rmtree(db_config.persist_directory)
                st.success(f"Cleared {db_config.name} database")
        else:
            # Clear all databases
            for db_type, db_config in DATABASES.items():
                if db_type in st.session_state.databases:
                    st.session_state.databases[db_type]._collection.delete()
                if os.path.exists(db_config.persist_directory):
                    shutil.rmtree(db_config.persist_directory)
            st.session_state.databases = {}
            st.success("Cleared all databases")
    except Exception as e:
        st.error(f"Error clearing database(s): {str(e)}")

def main():
    st.title("📚 RAG Database Router")

    init_session_state()

    # Sidebar for database management
    with st.sidebar:
        st.header("Database Management")
        if st.button("Clear All Databases"):
            clear_database()

        st.divider()
        st.subheader("Clear Individual Databases")
        for db_type, db_config in DATABASES.items():
            if st.button(f"Clear {db_config.name}"):
                clear_database(db_type)

    # Document upload section
    st.header("Document Upload")
    tabs = st.tabs([db.name for db in DATABASES.values()])

    for (db_type, db_config), tab in zip(DATABASES.items(), tabs):
        with tab:
            st.write(db_config.description)
            uploaded_file = st.file_uploader(
                "Upload PDF document",
                type="pdf",
                key=f"upload_{db_type}"
            )

            if uploaded_file:
                with st.spinner('Processing document...'):
                    texts = process_document(uploaded_file)
                    if texts:
                        db = get_or_create_db(db_type)
                        db.add_documents(texts)
                        st.success("Document processed and added to the database!")

    # Query section
    st.header("Ask Questions")
    question = st.text_input("Enter your question:")

    if question:
        with st.spinner('Finding answer...'):
            # Route the question
            db_type = route_query(question)
            db = get_or_create_db(db_type)

            # Display routing information
            st.info(f"Routing question to: {DATABASES[db_type].name}")

            # Get and display answer
            answer = query_database(db, question)
            st.write("### Answer")
            st.write(answer)


if __name__ == "__main__":
    main()
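With the dependencies installed and an OPENAI_API_KEY available to load_dotenv() (for example in a .env file next to the script), the app would typically be launched with `streamlit run rag_tutorials/rag_database_routing/rag_database_routing.py`. The pinned dependencies for the demo are listed below.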
langchain>=0.1.0
langchain-community>=0.0.10
langchain-core>=0.1.10
langchain-openai>=0.0.2  # required for the ChatOpenAI and OpenAIEmbeddings imports
chromadb>=0.4.22
streamlit>=1.29.0
python-dotenv>=1.0.0
pypdf>=4.0.0
sentence-transformers>=2.2.2
openai>=1.6.1
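Assuming the conventional filename requirements.txt (the file's path is not part of the diff shown here), these can be installed with `pip install -r requirements.txt` from inside the virtual environment.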