-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #46 from Madhuvod/rag-agent-cohere
Added new demo
- Loading branch information
Showing
3 changed files
with
396 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# RAG Agent with Cohere 🤖 | ||
|
||
A RAG Agentic system built with Cohere's new model Command-r7b-12-2024, Qdrant for vector storage, Langchain for RAG and LangGraph for orchestration. This application allows users to upload documents, ask questions about them, and get AI-powered responses with fallback to web search when needed. | ||
|
||
## Demo | ||
|
||
|
||
|
||
## Features | ||
|
||
- **Document Processing** | ||
- PDF document upload and processing | ||
- Automatic text chunking and embedding | ||
- Vector storage in Qdrant cloud | ||
|
||
- **Intelligent Querying** | ||
- RAG-based document retrieval | ||
- Similarity search with threshold filtering | ||
- Automatic fallback to web search when no relevant documents found | ||
- Source attribution for answers | ||
|
||
- **Advanced Capabilities** | ||
- DuckDuckGo web search integration | ||
- LangGraph agent for web research | ||
- Context-aware response generation | ||
- Long answer summarization | ||
|
||
- **Model Specific Features** | ||
- Command-r7b-12-2024 model for Chat and RAG | ||
- cohere embed-english-v3.0 model for embeddings | ||
- create_react_agent function from langgraph | ||
- DuckDuckGoSearchRun tool for web search | ||
|
||
## Prerequisites | ||
|
||
### 1. Cohere API Key | ||
1. Go to [Cohere Platform](https://dashboard.cohere.ai/api-keys) | ||
2. Sign up or log in to your account | ||
3. Navigate to API Keys section | ||
4. Create a new API key | ||
|
||
### 2. Qdrant Cloud Setup | ||
1. Visit [Qdrant Cloud](https://cloud.qdrant.io/) | ||
2. Create an account or sign in | ||
3. Create a new cluster | ||
4. Get your credentials: | ||
- Qdrant API Key: Found in API Keys section | ||
- Qdrant URL: Your cluster URL (format: `https://xxx-xxx.aws.cloud.qdrant.io`) | ||
|
||
|
||
## How to Run | ||
|
||
1. Clone the repository: | ||
```bash | ||
git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git | ||
cd awesome-llm-apps/rag_tutorials/rag_agent_cohere
``` | ||
|
||
2. Install dependencies: | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
3. Run the application:
```bash
streamlit run rag_agent_cohere.py
```
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,319 @@ | ||
import os | ||
import streamlit as st | ||
from langchain_community.document_loaders import PyPDFLoader | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
from langchain_cohere import CohereEmbeddings, ChatCohere | ||
from langchain_qdrant import QdrantVectorStore | ||
from qdrant_client import QdrantClient | ||
from qdrant_client.models import Distance, VectorParams | ||
from langchain.chains.combine_documents import create_stuff_documents_chain | ||
from langchain.chains import create_retrieval_chain | ||
from langchain import hub | ||
import tempfile | ||
from langgraph.prebuilt import create_react_agent | ||
from langchain_community.tools import DuckDuckGoSearchRun | ||
from typing import TypedDict, List | ||
from langchain_core.language_models import BaseLanguageModel | ||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | ||
from time import sleep | ||
from tenacity import retry, wait_exponential, stop_after_attempt | ||
|
||
|
||
def init_session_state():
    """Seed Streamlit session state with every key this app reads later.

    Safe to call on every rerun: existing values are never overwritten.
    """
    defaults = {
        'api_keys_submitted': False,  # credentials verified yet?
        'chat_history': [],           # list of {"role", "content"} dicts
        'vectorstore': None,          # populated after a document is indexed
        'qdrant_api_key': "",
        'qdrant_url': "",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
||
def sidebar_api_form():
    """Render the sidebar credentials form.

    Returns:
        bool: True once credentials have been verified, False while the
        form is still being shown.
    """
    with st.sidebar:
        st.header("API Credentials")

        if st.session_state.api_keys_submitted:
            st.success("API credentials verified")
            if st.button("Reset Credentials"):
                st.session_state.clear()
                st.rerun()
            return True

        with st.form("api_credentials"):
            cohere_key = st.text_input("Cohere API Key", type="password")
            qdrant_key = st.text_input("Qdrant API Key", type="password", help="Enter your Qdrant API key")
            qdrant_url = st.text_input("Qdrant URL",
                                       placeholder="https://xyz-example.eu-central.aws.cloud.qdrant.io:6333",
                                       help="Enter your Qdrant instance URL")

            if st.form_submit_button("Submit Credentials"):
                try:
                    # Probe the cluster before trusting the credentials.
                    probe = QdrantClient(url=qdrant_url, api_key=qdrant_key, timeout=60)
                    probe.get_collections()
                except Exception as e:
                    st.error(f"Qdrant connection failed: {str(e)}")
                else:
                    st.session_state.cohere_api_key = cohere_key
                    st.session_state.qdrant_api_key = qdrant_key
                    st.session_state.qdrant_url = qdrant_url
                    st.session_state.api_keys_submitted = True
                    st.success("Credentials verified!")
                    # BUG FIX: st.rerun() raises Streamlit's rerun control-flow
                    # exception; the original called it INSIDE the try, so the
                    # broad `except Exception` swallowed the rerun and showed a
                    # bogus "Qdrant connection failed" error. Rerun must happen
                    # outside the try/except.
                    st.rerun()
        return False
|
||
def init_qdrant() -> QdrantClient:
    """Build a Qdrant client from credentials stored in session state.

    Returns:
        QdrantClient: client targeting the configured cluster (60s timeout).

    Raises:
        ValueError: when either credential is missing from the session.
    """
    api_key = st.session_state.get("qdrant_api_key")
    cluster_url = st.session_state.get("qdrant_url")

    if not api_key:
        raise ValueError("Qdrant API key not provided")
    if not cluster_url:
        raise ValueError("Qdrant URL not provided")

    return QdrantClient(url=cluster_url, api_key=api_key, timeout=60)
|
||
# ---- Module-level bootstrap: executed top-to-bottom on every Streamlit rerun ----

init_session_state()

# Gate the rest of the app until credentials are verified in the sidebar.
if not sidebar_api_form():
    st.info("Please enter your API credentials in the sidebar to continue.")
    st.stop()

# Cohere embedding model used both for indexing chunks and for queries.
embedding = CohereEmbeddings(model="embed-english-v3.0",
                             cohere_api_key=st.session_state.cohere_api_key)

# Low-temperature chat model so RAG answers stay close to the retrieved text.
chat_model = ChatCohere(model="command-r7b-12-2024",
                        temperature=0.1,
                        max_tokens=512,
                        verbose=True,
                        cohere_api_key=st.session_state.cohere_api_key)

# Shared Qdrant client for the whole script (indexing, search, cleanup).
client = init_qdrant()
|
||
def process_document(file):
    """Extract and chunk text from an uploaded PDF.

    Args:
        file: Streamlit UploadedFile holding the PDF bytes.

    Returns:
        list: langchain Documents (1000-char chunks, 200-char overlap);
        empty list on any failure.
    """
    tmp_path = None
    try:
        # PyPDFLoader needs a real filesystem path, so spill the upload to disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            tmp_file.write(file.getvalue())
            tmp_path = tmp_file.name

        loader = PyPDFLoader(tmp_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        return text_splitter.split_documents(documents)
    except Exception as e:
        st.error(f"Error processing document: {e}")
        return []
    finally:
        # BUG FIX: the original only unlinked on success, leaking the temp
        # file whenever loading or splitting raised. Always clean up.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||
COLLECTION_NAME = "cohere_rag" | ||
|
||
def create_vector_stores(texts):
    """Ensure the Qdrant collection exists, then index *texts* into it.

    Args:
        texts: langchain Documents produced by process_document().

    Returns:
        QdrantVectorStore populated with *texts*, or None when anything fails.
    """
    try:
        # Create the collection if needed; a pre-existing collection is fine.
        try:
            vectors_config = VectorParams(size=1024, distance=Distance.COSINE)
            client.create_collection(collection_name=COLLECTION_NAME,
                                     vectors_config=vectors_config)
            st.success(f"Created new collection: {COLLECTION_NAME}")
        except Exception as creation_error:
            if "already exists" not in str(creation_error).lower():
                raise creation_error

        store = QdrantVectorStore(client=client,
                                  collection_name=COLLECTION_NAME,
                                  embedding=embedding)

        with st.spinner('Storing documents in Qdrant...'):
            store.add_documents(texts)
            st.success("Documents successfully stored in Qdrant!")

        return store

    except Exception as e:
        st.error(f"Error in vector store creation: {str(e)}")
        return None
|
||
# Define the state schema using TypedDict | ||
class AgentState(TypedDict): | ||
"""State schema for the agent.""" | ||
messages: List[HumanMessage | AIMessage | SystemMessage] | ||
is_last_step: bool | ||
|
||
class RateLimitedDuckDuckGo(DuckDuckGoSearchRun):
    """DuckDuckGo search tool with request pacing and backoff retries."""

    @retry(wait=wait_exponential(multiplier=1, min=4, max=10),
           stop=stop_after_attempt(3))
    def run(self, query: str) -> str:
        """Run search with rate limiting."""
        try:
            sleep(2)  # Pace requests to stay under DuckDuckGo's rate limit.
            return super().run(query)
        except Exception as e:
            if "Ratelimit" in str(e):
                sleep(5)  # Back off harder once actually rate-limited.
                return super().run(query)
            # BUG FIX: `raise e` replaced with bare `raise` so the original
            # traceback is preserved for tenacity's retry/inspection.
            raise
|
||
def create_fallback_agent(chat_model: BaseLanguageModel):
    """Create a LangGraph ReAct agent used when RAG retrieval finds nothing.

    Args:
        chat_model: chat model that drives the agent's reasoning.

    Returns:
        A compiled LangGraph agent exposing .invoke().
    """

    def web_research(query: str) -> str:
        """Web search with result formatting."""
        try:
            # CONSISTENCY FIX: use the rate-limited wrapper defined above.
            # The original instantiated the plain DuckDuckGoSearchRun,
            # leaving RateLimitedDuckDuckGo as dead code and making the
            # fallback path prone to rate-limit failures.
            search = RateLimitedDuckDuckGo(num_results=5)
            return search.run(query)
        except Exception as e:
            return f"Search failed: {str(e)}. Providing answer based on general knowledge."

    agent = create_react_agent(model=chat_model,
                               tools=[web_research],
                               debug=False)

    return agent
|
||
def process_query(vectorstore, query) -> tuple[str, list]:
    """Answer *query* via RAG, falling back to web research.

    Args:
        vectorstore: populated QdrantVectorStore to retrieve from.
        query: the user's question.

    Returns:
        (answer, sources): answer text plus the retrieved Documents;
        sources is empty when the answer came from the web or an error.
    """
    try:
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": 10,
                "score_threshold": 0.7
            }
        )

        relevant_docs = retriever.get_relevant_documents(query)

        if relevant_docs:
            # Standard RAG path: stuff retrieved chunks into the QA prompt.
            retrieval_qa_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
            combine_docs_chain = create_stuff_documents_chain(chat_model, retrieval_qa_prompt)
            retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)
            response = retrieval_chain.invoke({"input": query})
            return response['answer'], relevant_docs

        # No sufficiently similar chunks: research the question on the web.
        st.info("No relevant documents found. Searching web...")
        fallback_agent = create_fallback_agent(chat_model)

        with st.spinner('Researching...'):
            agent_input = {
                "messages": [
                    HumanMessage(content=f"""Please thoroughly research the question: '{query}' and provide a detailed and comprehensive response. Make sure to gather the latest information from credible sources. Minimum 400 words.""")
                ],
                "is_last_step": False
            }

            config = {"recursion_limit": 100}

            try:
                response = fallback_agent.invoke(agent_input, config=config)

                if isinstance(response, dict) and "messages" in response:
                    last_message = response["messages"][-1]
                    answer = last_message.content if hasattr(last_message, 'content') else str(last_message)
                    return f"""Comprehensive Research Results:
{answer}
""", []

                # BUG FIX: the original fell through here with no return,
                # implicitly yielding None and crashing the caller's
                # `answer, sources = process_query(...)` unpacking.
                return str(response), []

            except Exception:
                # Agent failed (rate limit, recursion, etc.): answer directly.
                fallback_response = chat_model.invoke(f"Please provide a general answer to: {query}").content
                return f"Web search unavailable. General response: {fallback_response}", []

    except Exception as e:
        st.error(f"Error: {str(e)}")
        return "I encountered an error. Please try rephrasing your question.", []
|
||
def post_process(answer, sources):
    """Trim the answer, prepend a summary when it is long, format sources.

    Args:
        answer: raw answer text.
        sources: Documents backing the answer (each has .page_content).

    Returns:
        (answer, formatted_sources): cleaned answer and numbered
        200-character source previews.
    """
    answer = answer.strip()

    # Long answers get a model-generated 2-3 sentence summary up front.
    if len(answer) > 500:
        summary = chat_model.invoke(
            f"Summarize the following answer in 2-3 sentences: {answer}"
        ).content
        answer = f"{summary}\n\nFull Answer: {answer}"

    formatted_sources = [
        f"{idx}. {doc.page_content[:200]}..."
        for idx, doc in enumerate(sources, 1)
    ]
    return answer, formatted_sources
|
||
# ---------------- Main UI (runs top-to-bottom on every Streamlit rerun) ----------------

st.title("RAG Agent with Cohere 🤖")

# NOTE(review): the uploader accepts jpg/jpeg, but process_document only
# handles PDFs via PyPDFLoader — image uploads will fail; confirm intent.
uploaded_file = st.file_uploader("Choose a PDF or Image File", type=["pdf", "jpg", "jpeg"])

# Index the uploaded file exactly once per session ('processed_file' guard).
if uploaded_file is not None and 'processed_file' not in st.session_state:
    with st.spinner('Processing file... This may take a while for images.'):
        texts = process_document(uploaded_file)
        vectorstore = create_vector_stores(texts)
        if vectorstore:
            st.session_state.vectorstore = vectorstore
            st.session_state.processed_file = True
            st.success('File uploaded and processed successfully!')
        else:
            st.error('Failed to process file. Please try again.')

# Replay the conversation so far, since the script re-executes each turn.
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if query := st.chat_input("Ask a question about the document:"):
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    if st.session_state.vectorstore:
        with st.chat_message("assistant"):
            try:
                answer, sources = process_query(st.session_state.vectorstore, query)
                st.markdown(answer)

                # Show supporting chunks only when the answer came from RAG
                # (web-search answers return an empty sources list).
                if sources:
                    with st.expander("Sources"):
                        for source in sources:
                            st.markdown(f"- {source.page_content[:200]}...")

                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": answer
                })

            except Exception as e:
                st.error(f"Error: {str(e)}")
                st.info("Please try asking your question again.")
    else:
        st.error("Please upload a document first.")

# Sidebar maintenance controls: wipe chat, or wipe Qdrant data entirely.
with st.sidebar:
    st.divider()
    col1, col2 = st.columns(2)
    with col1:
        if st.button('Clear Chat History'):
            st.session_state.chat_history = []
            st.rerun()
    with col2:
        if st.button('Clear All Data'):
            try:
                collections = client.get_collections().collections
                collection_names = [col.name for col in collections]

                # Delete the main collection and a possible "_compressed"
                # sibling if either exists.
                if COLLECTION_NAME in collection_names:
                    client.delete_collection(COLLECTION_NAME)
                if f"{COLLECTION_NAME}_compressed" in collection_names:
                    client.delete_collection(f"{COLLECTION_NAME}_compressed")

                st.session_state.vectorstore = None
                st.session_state.chat_history = []
                st.success("All data cleared successfully!")
                st.rerun()
            except Exception as e:
                st.error(f"Error clearing data: {str(e)}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
langgraph==0.2.53 | ||
langchain==0.3.11 | ||
langchain-community==0.0.10 | ||
cohere==5.11.4 | ||
qdrant-client==1.12.1 | ||
duckduckgo-search==4.1.1 | ||
streamlit==1.40.2 | ||
langchain-cohere==0.3.2 | ||
langchain-qdrant==0.2.0 |