Skip to content

Commit

Permalink
Merge pull request #46 from Madhuvod/rag-agent-cohere
Browse files Browse the repository at this point in the history
Added new demo
  • Loading branch information
Shubhamsaboo authored Dec 15, 2024
2 parents 0d1e509 + 67cd362 commit 72e788f
Show file tree
Hide file tree
Showing 3 changed files with 396 additions and 0 deletions.
68 changes: 68 additions & 0 deletions rag_tutorials/rag_agent_cohere/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# RAG Agent with Cohere 🤖

An agentic RAG system built with Cohere's Command R7B model (command-r7b-12-2024), Qdrant for vector storage, LangChain for RAG, and LangGraph for orchestration. This application lets users upload documents, ask questions about them, and get AI-powered responses, with a fallback to web search when no relevant documents are found.

## Demo



## Features

- **Document Processing**
- PDF document upload and processing
- Automatic text chunking and embedding
- Vector storage in Qdrant cloud

- **Intelligent Querying**
- RAG-based document retrieval
- Similarity search with threshold filtering
- Automatic fallback to web search when no relevant documents found
- Source attribution for answers

- **Advanced Capabilities**
- DuckDuckGo web search integration
- LangGraph agent for web research
- Context-aware response generation
- Long answer summarization

- **Model Specific Features**
- Command-r7b-12-2024 model for Chat and RAG
- cohere embed-english-v3.0 model for embeddings
- create_react_agent function from langgraph
- DuckDuckGoSearchRun tool for web search

## Prerequisites

### 1. Cohere API Key
1. Go to [Cohere Platform](https://dashboard.cohere.ai/api-keys)
2. Sign up or log in to your account
3. Navigate to API Keys section
4. Create a new API key

### 2. Qdrant Cloud Setup
1. Visit [Qdrant Cloud](https://cloud.qdrant.io/)
2. Create an account or sign in
3. Create a new cluster
4. Get your credentials:
- Qdrant API Key: Found in API Keys section
- Qdrant URL: Your cluster URL (format: `https://xxx-xxx.aws.cloud.qdrant.io`)


## How to Run

1. Clone the repository:
```bash
git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git
cd rag_tutorials/rag_agent_cohere
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

3. Run the Streamlit app:
```bash
streamlit run rag_agent_cohere.py
```


319 changes: 319 additions & 0 deletions rag_tutorials/rag_agent_cohere/rag_agent_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
import os
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_cohere import CohereEmbeddings, ChatCohere
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain import hub
import tempfile
from langgraph.prebuilt import create_react_agent
from langchain_community.tools import DuckDuckGoSearchRun
from typing import TypedDict, List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from time import sleep
from tenacity import retry, wait_exponential, stop_after_attempt


def init_session_state():
    """Seed st.session_state with every key this app reads, if missing."""
    defaults = {
        'api_keys_submitted': False,  # set True once Qdrant creds verify
        'chat_history': [],           # list of {"role", "content"} dicts
        'vectorstore': None,          # QdrantVectorStore after upload
        'qdrant_api_key': "",
        'qdrant_url': "",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

def sidebar_api_form():
    """Render the sidebar credentials form.

    Returns True when credentials are already verified for this session,
    False otherwise (including the render where the form was just submitted;
    a successful submit triggers st.rerun()).
    """
    with st.sidebar:
        st.header("API Credentials")

        # Already verified earlier in this session: just offer a reset.
        if st.session_state.api_keys_submitted:
            st.success("API credentials verified")
            if st.button("Reset Credentials"):
                st.session_state.clear()
                st.rerun()
            return True

        with st.form("api_credentials"):
            cohere_api_key = st.text_input("Cohere API Key", type="password")
            qdrant_api_key = st.text_input("Qdrant API Key", type="password", help="Enter your Qdrant API key")
            qdrant_endpoint = st.text_input("Qdrant URL",
                                            placeholder="https://xyz-example.eu-central.aws.cloud.qdrant.io:6333",
                                            help="Enter your Qdrant instance URL")

            if st.form_submit_button("Submit Credentials"):
                try:
                    # Probe the cluster before persisting anything.
                    probe = QdrantClient(url=qdrant_endpoint, api_key=qdrant_api_key, timeout=60)
                    probe.get_collections()
                except Exception as e:
                    st.error(f"Qdrant connection failed: {str(e)}")
                else:
                    st.session_state.cohere_api_key = cohere_api_key
                    st.session_state.qdrant_api_key = qdrant_api_key
                    st.session_state.qdrant_url = qdrant_endpoint
                    st.session_state.api_keys_submitted = True
                    st.success("Credentials verified!")
                    st.rerun()

        return False

def init_qdrant() -> QdrantClient:
    """Build a QdrantClient from credentials stored in session state.

    Raises ValueError when either credential is missing or empty.
    """
    api_key = st.session_state.get("qdrant_api_key")
    url = st.session_state.get("qdrant_url")

    if not api_key:
        raise ValueError("Qdrant API key not provided")
    if not url:
        raise ValueError("Qdrant URL not provided")

    return QdrantClient(url=url, api_key=api_key, timeout=60)

# --- App bootstrap (runs top-to-bottom on every Streamlit rerun) ---
init_session_state()

# Gate the rest of the app until Cohere/Qdrant credentials are verified.
if not sidebar_api_form():
    st.info("Please enter your API credentials in the sidebar to continue.")
    st.stop()

# Embedding model for document chunks; the Qdrant collection below is
# created with size=1024 to match it.
embedding = CohereEmbeddings(model="embed-english-v3.0",
                           cohere_api_key=st.session_state.cohere_api_key)

# Chat model shared by the RAG chain, the web-research agent, and
# answer summarization in post_process.
chat_model = ChatCohere(model="command-r7b-12-2024",
                       temperature=0.1,
                       max_tokens=512,
                       verbose=True,
                       cohere_api_key=st.session_state.cohere_api_key)

# Module-level Qdrant client used for collection management and storage.
client = init_qdrant()

def process_document(file):
    """Load an uploaded PDF and split it into overlapping text chunks.

    Args:
        file: a Streamlit UploadedFile (read via .getvalue()).

    Returns:
        A list of langchain Document chunks (1000 chars, 200 overlap),
        or [] on any failure (the error is surfaced via st.error).
    """
    tmp_path = None
    try:
        # PyPDFLoader needs a real path, so spill the upload to a temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            tmp_file.write(file.getvalue())
            tmp_path = tmp_file.name

        loader = PyPDFLoader(tmp_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        return text_splitter.split_documents(documents)
    except Exception as e:
        st.error(f"Error processing document: {e}")
        return []
    finally:
        # FIX: previously the temp file was only unlinked on success,
        # leaking it whenever loading/splitting raised.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)

# Name of the Qdrant collection holding the document embeddings.
COLLECTION_NAME = "cohere_rag"

def create_vector_stores(texts):
    """Create and populate vector store with documents.

    Ensures the Qdrant collection exists, then embeds and stores *texts*.
    Returns the QdrantVectorStore on success, None on failure.
    """
    try:
        # Create the collection lazily; tolerate a pre-existing one.
        try:
            client.create_collection(collection_name=COLLECTION_NAME,
                                     vectors_config=VectorParams(size=1024,
                                                                 distance=Distance.COSINE))
            st.success(f"Created new collection: {COLLECTION_NAME}")
        except Exception as creation_error:
            if "already exists" not in str(creation_error).lower():
                raise creation_error

        store = QdrantVectorStore(client=client,
                                  collection_name=COLLECTION_NAME,
                                  embedding=embedding)

        with st.spinner('Storing documents in Qdrant...'):
            store.add_documents(texts)
            st.success("Documents successfully stored in Qdrant!")

        return store

    except Exception as e:
        st.error(f"Error in vector store creation: {str(e)}")
        return None

# Define the state schema using TypedDict
class AgentState(TypedDict):
    """State schema for the agent.

    Mirrors the dict built for the fallback agent in process_query
    (a ``messages`` list plus an ``is_last_step`` flag).
    NOTE(review): this type is never referenced elsewhere in this file.
    """
    # Conversation so far, in order.
    messages: List[HumanMessage | AIMessage | SystemMessage]
    # LangGraph flag marking the final reasoning step.
    is_last_step: bool

class RateLimitedDuckDuckGo(DuckDuckGoSearchRun):
    """DuckDuckGo search tool with throttling and retry on rate limits.

    NOTE(review): defined but never used — create_fallback_agent builds a
    plain DuckDuckGoSearchRun instead; consider wiring this in or removing.
    """
    # Exponential backoff (4s..10s), up to 3 attempts total.
    @retry(wait=wait_exponential(multiplier=1, min=4, max=10),
           stop=stop_after_attempt(3))
    def run(self, query: str) -> str:
        """Run search with rate limiting."""
        try:
            sleep(2) # Add delay between requests
            return super().run(query)
        except Exception as e:
            if "Ratelimit" in str(e):
                sleep(5) # Longer delay on rate limit
                return super().run(query)
            raise e

def create_fallback_agent(chat_model: BaseLanguageModel):
    """Build a LangGraph ReAct agent whose only tool is DuckDuckGo search."""

    def web_research(query: str) -> str:
        """Web search with result formatting."""
        # Docstring above doubles as the tool description for the agent.
        try:
            return DuckDuckGoSearchRun(num_results=5).run(query)
        except Exception as e:
            return f"Search failed: {str(e)}. Providing answer based on general knowledge."

    return create_react_agent(model=chat_model,
                              tools=[web_research],
                              debug=False)

def process_query(vectorstore, query) -> tuple[str, list]:
    """Answer *query* via RAG, falling back to web research when needed.

    Returns:
        (answer, sources) — ``sources`` is the list of retrieved documents
        on the RAG path and ``[]`` on every fallback path.
    """
    try:
        # Score floor: weak matches trigger the web-search fallback
        # instead of producing a low-quality RAG answer.
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": 10,
                "score_threshold": 0.7
            }
        )

        relevant_docs = retriever.get_relevant_documents(query)

        if relevant_docs:
            # RAG path: stuff retrieved chunks into a retrieval-QA prompt.
            retrieval_qa_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
            combine_docs_chain = create_stuff_documents_chain(chat_model, retrieval_qa_prompt)
            retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)
            response = retrieval_chain.invoke({"input": query})
            return response['answer'], relevant_docs

        st.info("No relevant documents found. Searching web...")
        fallback_agent = create_fallback_agent(chat_model)

        with st.spinner('Researching...'):
            agent_input = {
                "messages": [
                    HumanMessage(content=f"""Please thoroughly research the question: '{query}' and provide a detailed and comprehensive response. Make sure to gather the latest information from credible sources. Minimum 400 words.""")
                ],
                "is_last_step": False
            }

            config = {"recursion_limit": 100}

            try:
                response = fallback_agent.invoke(agent_input, config=config)

                if isinstance(response, dict) and "messages" in response:
                    last_message = response["messages"][-1]
                    answer = last_message.content if hasattr(last_message, 'content') else str(last_message)

                    return f"""Comprehensive Research Results:
{answer}
""", []

                # FIX: an unexpected agent response shape previously fell
                # through and implicitly returned None, crashing the
                # `answer, sources = ...` unpacking at the call site.
                return str(response), []

            except Exception:
                # Agent failed (e.g. search rate-limited): answer directly.
                fallback_response = chat_model.invoke(f"Please provide a general answer to: {query}").content
                return f"Web search unavailable. General response: {fallback_response}", []

    except Exception as e:
        st.error(f"Error: {str(e)}")
        return "I encountered an error. Please try rephrasing your question.", []

def post_process(answer, sources):
    """Post-process the answer and format sources.

    Strips the answer, prepends a model-written 2-3 sentence summary when it
    exceeds 500 characters, and truncates each source to its first 200 chars.
    NOTE(review): not called anywhere in this file.
    """
    answer = answer.strip()

    if len(answer) > 500:
        # Long answers get a short summary on top of the full text.
        summary_prompt = f"Summarize the following answer in 2-3 sentences: {answer}"
        answer = f"{chat_model.invoke(summary_prompt).content}\n\nFull Answer: {answer}"

    formatted_sources = [
        f"{position}. {document.page_content[:200]}..."
        for position, document in enumerate(sources, 1)
    ]
    return answer, formatted_sources

st.title("RAG Agent with Cohere 🤖")

# One-shot upload: the 'processed_file' flag stops Streamlit reruns from
# re-embedding the same document.
# NOTE(review): the uploader accepts jpg/jpeg but process_document only
# handles PDFs (PyPDFLoader, .pdf suffix) — confirm image support intent.
uploaded_file = st.file_uploader("Choose a PDF or Image File", type=["pdf", "jpg", "jpeg"])

if uploaded_file is not None and 'processed_file' not in st.session_state:
    with st.spinner('Processing file... This may take a while for images.'):
        texts = process_document(uploaded_file)
        vectorstore = create_vector_stores(texts)
        if vectorstore:
            st.session_state.vectorstore = vectorstore
            st.session_state.processed_file = True
            st.success('File uploaded and processed successfully!')
        else:
            st.error('Failed to process file. Please try again.')

# Replay the stored conversation on every rerun.
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Chat turn: record the user message, answer it via process_query (RAG with
# web-search fallback), then record the assistant message.
if query := st.chat_input("Ask a question about the document:"):
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    if st.session_state.vectorstore:
        with st.chat_message("assistant"):
            try:
                answer, sources = process_query(st.session_state.vectorstore, query)
                st.markdown(answer)

                # Show the first 200 chars of each supporting chunk.
                if sources:
                    with st.expander("Sources"):
                        for source in sources:
                            st.markdown(f"- {source.page_content[:200]}...")

                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": answer
                })

            except Exception as e:
                st.error(f"Error: {str(e)}")
                st.info("Please try asking your question again.")
    else:
        st.error("Please upload a document first.")

# Sidebar maintenance actions.
with st.sidebar:
    st.divider()
    col1, col2 = st.columns(2)
    with col1:
        # Clears only the on-screen conversation; vectors stay in Qdrant.
        if st.button('Clear Chat History'):
            st.session_state.chat_history = []
            st.rerun()
    with col2:
        # Drops the Qdrant collection(s) and resets local state.
        if st.button('Clear All Data'):
            try:
                collections = client.get_collections().collections
                collection_names = [col.name for col in collections]

                if COLLECTION_NAME in collection_names:
                    client.delete_collection(COLLECTION_NAME)
                # NOTE(review): nothing in this file creates a
                # "_compressed" collection — possibly a leftover.
                if f"{COLLECTION_NAME}_compressed" in collection_names:
                    client.delete_collection(f"{COLLECTION_NAME}_compressed")

                st.session_state.vectorstore = None
                st.session_state.chat_history = []
                st.success("All data cleared successfully!")
                st.rerun()
            except Exception as e:
                st.error(f"Error clearing data: {str(e)}")
9 changes: 9 additions & 0 deletions rag_tutorials/rag_agent_cohere/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
langgraph==0.2.53
langchain==0.3.11
langchain-community==0.3.11
cohere==5.11.4
qdrant-client==1.12.1
duckduckgo-search==4.1.1
streamlit==1.40.2
langchain-cohere==0.3.2
langchain-qdrant==0.2.0

0 comments on commit 72e788f

Please sign in to comment.