Skip to content

Commit

Permalink
added qdrant as db
Browse files Browse the repository at this point in the history
  • Loading branch information
Madhuvod committed Dec 25, 2024
1 parent d0c0798 commit f9c755d
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 65 deletions.
200 changes: 136 additions & 64 deletions rag_tutorials/rag_database_routing/rag_database_routing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
from typing import List, Dict, Any, Literal
from typing import List, Dict, Any, Literal, Optional
from dataclasses import dataclass
import streamlit as st
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
import tempfile
Expand All @@ -19,11 +19,17 @@
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.language_models import BaseLanguageModel
from langchain.prompts import ChatPromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

def init_session_state():
"""Initialize session state variables"""
if 'openai_api_key' not in st.session_state:
st.session_state.openai_api_key = ""
if 'qdrant_url' not in st.session_state:
st.session_state.qdrant_url = ""
if 'qdrant_api_key' not in st.session_state:
st.session_state.qdrant_api_key = ""
if 'embeddings' not in st.session_state:
st.session_state.embeddings = None
if 'llm' not in st.session_state:
Expand All @@ -40,61 +46,68 @@ def init_session_state():
class CollectionConfig:
name: str
description: str
collection_name: str
persist_directory: str
collection_name: str # Used as the Qdrant collection name

# Collection configurations
COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
"products": CollectionConfig(
name="Product Information",
description="Product details, specifications, and features",
collection_name="products_collection",
persist_directory=f"{PERSIST_DIRECTORY}/products"
collection_name="products_collection"
),
"support": CollectionConfig(
name="Customer Support & FAQ",
description="Customer support information, frequently asked questions, and guides",
collection_name="support_collection",
persist_directory=f"{PERSIST_DIRECTORY}/support"
collection_name="support_collection"
),
"finance": CollectionConfig(
name="Financial Information",
description="Financial data, revenue, costs, and liabilities",
collection_name="finance_collection",
persist_directory=f"{PERSIST_DIRECTORY}/finance"
collection_name="finance_collection"
)
}

def initialize_models():
"""Initialize OpenAI models with API key"""
if st.session_state.openai_api_key:
"""Initialize OpenAI models and Qdrant client"""
if (st.session_state.openai_api_key and
st.session_state.qdrant_url and
st.session_state.qdrant_api_key):

os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key
st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
st.session_state.llm = ChatOpenAI(temperature=0)

# Ensure directories exist
for collection_config in COLLECTIONS.values():
os.makedirs(collection_config.persist_directory, exist_ok=True)

# Initialize Chroma collections
st.session_state.databases = {
"products": Chroma(
collection_name=COLLECTIONS["products"].collection_name,
embedding_function=st.session_state.embeddings,
persist_directory=COLLECTIONS["products"].persist_directory
),
"support": Chroma(
collection_name=COLLECTIONS["support"].collection_name,
embedding_function=st.session_state.embeddings,
persist_directory=COLLECTIONS["support"].persist_directory
),
"finance": Chroma(
collection_name=COLLECTIONS["finance"].collection_name,
embedding_function=st.session_state.embeddings,
persist_directory=COLLECTIONS["finance"].persist_directory
try:
# Initialize Qdrant client with session state credentials
client = QdrantClient(
url=st.session_state.qdrant_url,
api_key=st.session_state.qdrant_api_key
)
}
return True

# Test connection
client.get_collections()
vector_size = 1536
st.session_state.databases = {}
for db_type, config in COLLECTIONS.items():
try:
client.get_collection(config.collection_name)
except Exception:
# Create collection if it doesn't exist
client.create_collection(
collection_name=config.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE)
)

st.session_state.databases[db_type] = Qdrant(
client=client,
collection_name=config.collection_name,
embeddings=st.session_state.embeddings
)

return True
except Exception as e:
st.error(f"Failed to connect to Qdrant: {str(e)}")
return False
return False

def process_document(file) -> List[Document]:
Expand Down Expand Up @@ -136,33 +149,62 @@ def create_routing_agent() -> Agent:
"1. For questions about products, features, specifications, or item details, or product manuals → return 'products'",
"2. For questions about help, guidance, troubleshooting, or customer service, FAQ, or guides → return 'support'",
"3. For questions about costs, revenue, pricing, or financial data, or financial reports and investments → return 'finance'",
"4. Return ONLY the database name, no other text or explanation"
"4. Return ONLY the database name, no other text or explanation",
"5. If you're not confident about the routing, return an empty response"
],
markdown=False,
show_tool_calls=False
)

def route_query(question: str) -> DatabaseType:
def route_query(question: str) -> Optional[DatabaseType]:
"""Route query by searching all databases and comparing relevance scores.
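Falls back to LLM routing when similarity scores are below the confidence threshold.
Returns None if no suitable database is found or routing raises an error."""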
try:
best_score = -1
best_db_type = None
all_scores = {} # Store all scores for debugging

# Search each database and compare relevance scores
for db_type, db in st.session_state.databases.items():
results = db.similarity_search_with_score(
question,
k=3
)

if results:
avg_score = sum(score for _, score in results) / len(results)
all_scores[db_type] = avg_score

if avg_score > best_score:
best_score = avg_score
best_db_type = db_type

confidence_threshold = 0.5
if best_score >= confidence_threshold and best_db_type:
st.success(f"Using vector similarity routing: {best_db_type} (confidence: {best_score:.3f})")
return best_db_type

st.warning(f"Low confidence scores (below {confidence_threshold}), falling back to LLM routing")

# Fallback to LLM routing
routing_agent = create_routing_agent()
response = routing_agent.run(question)

db_type = (response.content
.strip()
.lower()
.translate(str.maketrans('', '', '`\'"'))) # More elegant string cleaning

# Validate database type
if db_type not in COLLECTIONS:
st.warning(f"Invalid database type: {db_type}, defaulting to products")
return "products"
.translate(str.maketrans('', '', '`\'"')))

st.info(f"Routing question to {db_type} database")
return db_type
if db_type in COLLECTIONS:
st.success(f"Using LLM routing decision: {db_type}")
return db_type

st.warning("No suitable database found, will use web search fallback")
return None

except Exception as e:
st.error(f"Routing error: {str(e)}")
return "products"
return None

def create_fallback_agent(chat_model: BaseLanguageModel):
"""Create a LangGraph agent for web research."""
Expand All @@ -184,11 +226,12 @@ def web_research(query: str) -> str:

return agent

def query_database(db: Chroma, question: str) -> tuple[str, list]:
def query_database(db: Qdrant, question: str) -> tuple[str, list]:
"""Query the database and return answer and relevant documents"""
try:
retriever = db.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": 4, "score_threshold": 0.3}
search_type="similarity",
search_kwargs={"k": 4}
)

relevant_docs = retriever.get_relevant_documents(question)
Expand All @@ -210,7 +253,8 @@ def query_database(db: Chroma, question: str) -> tuple[str, list]:

response = retrieval_chain.invoke({"input": question})
return response['answer'], relevant_docs
return _handle_web_fallback(question)

raise ValueError("No relevant documents found in database")

except Exception as e:
st.error(f"Error: {str(e)}")
Expand Down Expand Up @@ -244,25 +288,49 @@ def main():
st.set_page_config(page_title="RAG Agent with Database Routing", page_icon="📚")
st.title("📚 RAG Agent with Database Routing")

# Sidebar for API key and database management
# Sidebar for API keys and configuration
with st.sidebar:
st.header("Configuration")

# OpenAI API Key
api_key = st.text_input(
"Enter OpenAI API Key:",
type="password",
value=st.session_state.openai_api_key,
key="api_key_input"
)

# Qdrant Configuration
qdrant_url = st.text_input(
"Enter Qdrant URL:",
value=st.session_state.qdrant_url,
help="Example: https://your-cluster.qdrant.tech"
)

qdrant_api_key = st.text_input(
"Enter Qdrant API Key:",
type="password",
value=st.session_state.qdrant_api_key
)

# Update session state
if api_key:
st.session_state.openai_api_key = api_key
if qdrant_url:
st.session_state.qdrant_url = qdrant_url
if qdrant_api_key:
st.session_state.qdrant_api_key = qdrant_api_key

# Initialize models if all credentials are provided
if (st.session_state.openai_api_key and
st.session_state.qdrant_url and
st.session_state.qdrant_api_key):
if initialize_models():
st.success("API Key set successfully!")
st.success("Connected to OpenAI and Qdrant successfully!")
else:
st.error("Invalid API Key")

if not st.session_state.openai_api_key:
st.warning("Please enter your OpenAI API key to continue")
st.error("Failed to initialize. Please check your credentials.")
else:
st.warning("Please enter all required credentials to continue")
st.stop()

st.markdown("---")
Expand Down Expand Up @@ -302,15 +370,19 @@ def main():
with st.spinner('Finding answer...'):
# Route the question
collection_type = route_query(question)
db = st.session_state.databases[collection_type]

# Display routing information
st.info(f"Routing question to: {COLLECTIONS[collection_type].name}")

# Get and display answer
answer, relevant_docs = query_database(db, question)
st.write("### Answer")
st.write(answer)
if collection_type is None:
# Use web search fallback directly
answer, relevant_docs = _handle_web_fallback(question)
st.write("### Answer (from web search)")
st.write(answer)
else:
# Display routing information and query the database
st.info(f"Routing question to: {COLLECTIONS[collection_type].name}")
db = st.session_state.databases[collection_type]
answer, relevant_docs = query_database(db, question)
st.write("### Answer")
st.write(answer)

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion rag_tutorials/rag_database_routing/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
langchain==0.3.12
langchain-community==0.3.12
langchain-core==0.3.28
chromadb==0.5.20
qdrant-client==1.12.1
streamlit>=1.29.0
pypdf>=4.0.0
sentence-transformers>=2.2.2
Expand Down

0 comments on commit f9c755d

Please sign in to comment.