From 315092de278c0e2fe5c5d41b6ad840d4c49d232b Mon Sep 17 00:00:00 2001
From: Madhu
Date: Fri, 27 Dec 2024 02:45:11 +0530
Subject: [PATCH] new demo - crag

---
 rag_tutorials/corrective_rag/README.md        |  92 ++++
 .../corrective_rag/corrective_rag.py          | 453 ++++++++++++++++++
 rag_tutorials/corrective_rag/requirements.txt |  19 +
 3 files changed, 564 insertions(+)
 create mode 100644 rag_tutorials/corrective_rag/README.md
 create mode 100644 rag_tutorials/corrective_rag/corrective_rag.py
 create mode 100644 rag_tutorials/corrective_rag/requirements.txt

diff --git a/rag_tutorials/corrective_rag/README.md b/rag_tutorials/corrective_rag/README.md
new file mode 100644
index 00000000..a149b42b
--- /dev/null
+++ b/rag_tutorials/corrective_rag/README.md
@@ -0,0 +1,92 @@
+# Corrective RAG Demo
+
+This project demonstrates Corrective RAG (CRAG), an advanced Retrieval Augmented Generation approach that adds self-reflection and self-grading over retrieved documents. Document relevance checking, query transformation, and a web search fallback work together to substantially improve the quality of responses. A complete step-by-step explanation of CRAG is given below.
+
+## Features
+
+- **Smart Document Retrieval**: Uses a Qdrant vector store for efficient document retrieval
+- **Document Relevance Grading**: Employs Claude 3.5 Sonnet to assess document relevance
+- **Query Transformation**: Improves search results by optimizing queries when needed
+- **Web Search Fallback**: Uses the Tavily API for web search when local documents aren't sufficient
+- **Multi-Model Approach**: Combines OpenAI embeddings and Claude 3.5 Sonnet for different tasks
+- **Interactive UI**: Built with Streamlit for easy document upload and querying
+
+## How to Run?
+
+1. **Clone the Repository**:
+   ```bash
+   git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git
+   cd awesome-llm-apps/rag_tutorials/corrective_rag
+   ```
+
+2. **Install Dependencies**:
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+3. **Set Up API Keys**:
+   You'll need to obtain the following API keys:
+   - OpenAI API key (for embeddings)
+   - Anthropic API key (for Claude 3.5 Sonnet as the LLM)
+   - Tavily API key (for web search)
+   - Qdrant API key and URL
+
+4. **Run the Application**:
+   ```bash
+   streamlit run corrective_rag.py
+   ```
+
+5. **Use the Application**:
+   - Upload documents or provide URLs
+   - Enter your questions in the query box
+   - View the step-by-step Corrective RAG process
+   - Get comprehensive answers
+
+## Technologies Used
+
+- **LangChain**: For RAG orchestration and chains
+- **LangGraph**: For workflow management
+- **Qdrant**: Vector database for document storage
+- **Claude 3.5 Sonnet**: Main language model for analysis and generation
+- **OpenAI**: For document embeddings
+- **Tavily**: For web search capabilities
+- **Streamlit**: For the user interface
+
+## CRAG Step-by-Step Explanation
+
+### 1. Initial Retrieval
+
+A user query is presented to the system, and an existing retriever model gathers relevant documents from a knowledge base. Any retrieval model can fill this role.
+
+### 2. Evaluation of Retrieved Documents
+
+A lightweight retrieval evaluator assesses the relevance of each retrieved document to the user query, assigning each document a confidence score that indicates how relevant the evaluator believes it to be. In this demo the evaluator is an LLM grader (Claude 3.5 Sonnet), as sketched below.
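+
+A minimal, self-contained sketch of such an evaluator, mirroring the grading chain in `corrective_rag.py` (the prompt wording and function name here are illustrative, not the app's exact text):
+
+```python
+# Illustrative relevance evaluator: grade one document against the query.
+import json
+
+from langchain_anthropic import ChatAnthropic
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+
+def grade_document(question: str, document_text: str, api_key: str) -> str:
+    """Return "yes" or "no" depending on whether the document is relevant."""
+    llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", api_key=api_key, temperature=0)
+    prompt = PromptTemplate(
+        template=(
+            "Grade the relevance of the document to the question. "
+            'Return ONLY a JSON object like {{"score": "yes"}} or {{"score": "no"}}.\n'
+            "Document: {context}\nQuestion: {question}"
+        ),
+        input_variables=["context", "question"],
+    )
+    chain = prompt | llm | StrOutputParser()
+    response = chain.invoke({"context": document_text, "question": question})
+    return json.loads(response).get("score", "no")
+```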
+
+### 3. Action Trigger
+
+Based on the confidence scores, the system categorizes each retrieved document and decides which action to take:
+
+- **Correct**: the confidence score is above an upper threshold.
+- **Incorrect**: the confidence score is below a lower threshold.
+- **Ambiguous**: the confidence score falls between the "Correct" and "Incorrect" thresholds.
+
+### 4. Handling of Retrieved Documents
+
+- **Correct documents** undergo a knowledge refinement process:
+  - **Decomposition**: the document is segmented into smaller knowledge strips, typically a few sentences each.
+  - **Filtering**: the retrieval evaluator is used again to assess the relevance of each knowledge strip, and irrelevant strips are discarded.
+  - **Recomposition**: the remaining relevant strips are recombined into a refined representation of the document's essential knowledge.
+- **Incorrect documents** are discarded, and the system falls back to web search for additional information:
+  - **Query Rewriting**: the user query is rewritten into a form suited to web search, typically focusing on keywords.
+  - **Web Search**: a web search API finds pages related to the rewritten query; authoritative sources such as Wikipedia are preferred.
+  - **Knowledge Selection**: the content of those pages is transcribed, and the same refinement process (decomposition, filtering, recomposition) extracts the most relevant information.
+- **Ambiguous documents** trigger both paths: the system combines the refined knowledge from "Correct" documents with the external knowledge from web search, giving the generator a comprehensive set of information (see the sketch after this list).
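+
+The demo app simplifies this trigger to a binary yes/no grade, but the paper's three-way version is easy to picture. A minimal sketch, where the thresholds, the naive sentence split, and the helper names are assumptions of this illustration:
+
+```python
+# Illustrative three-way action trigger plus knowledge refinement.
+from typing import Callable
+
+UPPER, LOWER = 0.7, 0.3  # assumed thresholds for the sketch
+
+def action_for(score: float) -> str:
+    if score >= UPPER:
+        return "correct"    # refine this document's knowledge strips
+    if score <= LOWER:
+        return "incorrect"  # discard it and fall back to web search
+    return "ambiguous"      # use both refined strips and web results
+
+def refine(document: str, question: str, is_relevant: Callable[[str, str], bool]) -> str:
+    """Decompose -> filter -> recompose a 'correct' document."""
+    strips = [s.strip() for s in document.split(".") if s.strip()]  # decomposition (naive)
+    kept = [s for s in strips if is_relevant(question, s)]          # filtering via the evaluator
+    return ". ".join(kept)                                          # recomposition
+```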
+
+### 5. Generation of Response
+
+The refined knowledge from the retrieved documents and/or web searches is presented to a generative language model, which generates a response to the user query based on that knowledge.

diff --git a/rag_tutorials/corrective_rag/corrective_rag.py b/rag_tutorials/corrective_rag/corrective_rag.py
new file mode 100644
index 00000000..bd407dfe
--- /dev/null
+++ b/rag_tutorials/corrective_rag/corrective_rag.py
@@ -0,0 +1,453 @@
+import json
+import os
+import pprint
+import re
+import tempfile
+from typing import Any, Dict, TypedDict
+
+import nest_asyncio
+import streamlit as st
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_anthropic import ChatAnthropic
+from langchain_community.document_loaders import PyPDFLoader, TextLoader, WebBaseLoader
+from langchain_community.tools import TavilySearchResults
+from langchain_community.vectorstores import Qdrant
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from langchain_openai import OpenAIEmbeddings
+from langgraph.graph import END, StateGraph
+from qdrant_client import QdrantClient
+from qdrant_client.models import Distance, VectorParams
+from tenacity import retry, stop_after_attempt, wait_exponential
+
+
+nest_asyncio.apply()
+
+retriever = None
+
+def initialize_session_state():
+    """Initialize session state variables for API keys and URLs."""
+    if 'initialized' not in st.session_state:
+        st.session_state.initialized = False
+        # Initialize API keys and URLs
+        st.session_state.anthropic_api_key = ""
+        st.session_state.openai_api_key = ""
+        st.session_state.tavily_api_key = ""
+        st.session_state.qdrant_api_key = ""
+        st.session_state.qdrant_url = "http://localhost:6333"
+        st.session_state.doc_url = "https://arxiv.org/pdf/2307.09288.pdf"
+
+def setup_sidebar():
+    """Set up the sidebar for API keys and configuration."""
+    with st.sidebar:
+        st.subheader("API Configuration")
+        st.session_state.anthropic_api_key = st.text_input("Anthropic API Key", value=st.session_state.anthropic_api_key, type="password", help="Required for the Claude 3.5 Sonnet model")
+        st.session_state.openai_api_key = st.text_input("OpenAI API Key", value=st.session_state.openai_api_key, type="password")
+        st.session_state.tavily_api_key = st.text_input("Tavily API Key", value=st.session_state.tavily_api_key, type="password")
+        st.session_state.qdrant_url = st.text_input("Qdrant URL", value=st.session_state.qdrant_url)
+        st.session_state.qdrant_api_key = st.text_input("Qdrant API Key", value=st.session_state.qdrant_api_key, type="password")
+        st.session_state.doc_url = st.text_input("Document URL", value=st.session_state.doc_url)
+
+        if not all([st.session_state.openai_api_key, st.session_state.anthropic_api_key, st.session_state.qdrant_url]):
+            st.warning("Please provide the required API keys and URLs")
+            st.stop()
+
+        st.session_state.initialized = True
+
+initialize_session_state()
+setup_sidebar()
+
+# Read configuration from session state
+openai_api_key = st.session_state.openai_api_key
+tavily_api_key = st.session_state.tavily_api_key
+anthropic_api_key = st.session_state.anthropic_api_key
+
+# Embedding model used to index and retrieve documents
+embeddings = OpenAIEmbeddings(
+    model="text-embedding-3-small",
+    api_key=st.session_state.openai_api_key
+)
+
+# Qdrant client backing the vector store
+client = QdrantClient(
+    url=st.session_state.qdrant_url,
+    api_key=st.session_state.qdrant_api_key
+)
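+
+# Tavily search occasionally fails transiently, so the call below is wrapped
+# with tenacity's @retry: up to 3 attempts with exponential backoff bounded
+# between 4 and 10 seconds.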
+@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) +def execute_tavily_search(tool, query): + return tool.invoke({"query": query}) + +def web_search(state): + """Web search based on the re-phrased question using Tavily API.""" + print("~-web search-~") + state_dict = state["keys"] + question = state_dict["question"] + documents = state_dict["documents"] + + # Create progress placeholder + progress_placeholder = st.empty() + progress_placeholder.info("Initiating web search...") + + try: + # Validate Tavily API key + if not st.session_state.tavily_api_key: + progress_placeholder.warning("Tavily API key not provided - skipping web search") + return {"keys": {"documents": documents, "question": question}} + + progress_placeholder.info("Configuring search tool...") + + # Initialize Tavily search tool + tool = TavilySearchResults( + api_key=st.session_state.tavily_api_key, + max_results=3, + search_depth="advanced" + ) + + # Execute search with retry logic + progress_placeholder.info("Executing search query...") + try: + search_results = execute_tavily_search(tool, question) + except Exception as search_error: + progress_placeholder.error(f"Search failed after retries: {str(search_error)}") + return {"keys": {"documents": documents, "question": question}} + + if not search_results: + progress_placeholder.warning("No search results found") + return {"keys": {"documents": documents, "question": question}} + + # Process results + progress_placeholder.info("Processing search results...") + web_results = [] + for result in search_results: + # Extract and format relevant information + content = ( + f"Title: {result.get('title', 'No title')}\n" + f"Content: {result.get('content', 'No content')}\n" + ) + web_results.append(content) + + # Create document from results + web_document = Document( + page_content="\n\n".join(web_results), + metadata={ + "source": "tavily_search", + "query": question, + "result_count": len(web_results) + } + ) + documents.append(web_document) + + progress_placeholder.success(f"Successfully added {len(web_results)} search results") + + except Exception as error: + error_msg = f"Web search error: {str(error)}" + print(error_msg) + progress_placeholder.error(error_msg) + finally: + progress_placeholder.empty() + return {"keys": {"documents": documents, "question": question}} + + +def load_documents(file_or_url: str, is_url: bool = True) -> list: + try: + if is_url: + loader = WebBaseLoader(file_or_url) + loader.requests_per_second = 1 + else: + file_extension = os.path.splitext(file_or_url)[1].lower() + if file_extension == '.pdf': + loader = PyPDFLoader(file_or_url) + elif file_extension in ['.txt', '.md']: + loader = TextLoader(file_or_url) + else: + raise ValueError(f"Unsupported file type: {file_extension}") + + return loader.load() + except Exception as e: + st.error(f"Error loading document: {str(e)}") + return [] + +st.subheader("Document Input") +input_option = st.radio("Choose input method:", ["URL", "File Upload"]) + +docs = None + +if input_option == "URL": + url = st.text_input("Enter document URL:", value=st.session_state.doc_url) + if url: + docs = load_documents(url, is_url=True) +else: + uploaded_file = st.file_uploader("Upload a document", type=['pdf', 'txt', 'md']) + if uploaded_file: + # Create a temporary file to store the upload + with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file: + tmp_file.write(uploaded_file.getvalue()) + docs = load_documents(tmp_file.name, 
is_url=False)
+        # Clean up the temporary file
+        os.unlink(tmp_file.name)
+
+if docs:
+    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+        chunk_size=500, chunk_overlap=100
+    )
+    all_splits = text_splitter.split_documents(docs)
+
+    client = QdrantClient(url=st.session_state.qdrant_url, api_key=st.session_state.qdrant_api_key)
+    collection_name = "rag-qdrant"
+
+    try:
+        # Drop any existing collection so each run starts from a clean index
+        client.delete_collection(collection_name)
+    except Exception:
+        pass
+
+    client.create_collection(
+        collection_name=collection_name,
+        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
+    )
+
+    # Create the vector store
+    vectorstore = Qdrant(
+        client=client,
+        collection_name=collection_name,
+        embeddings=embeddings,
+    )
+
+    # Index the document chunks
+    vectorstore.add_documents(all_splits)
+    retriever = vectorstore.as_retriever()
+
+
+class GraphState(TypedDict):
+    keys: Dict[str, Any]
+
+
+def retrieve(state):
+    """Retrieve documents from the vector store for the current question."""
+    print("~-retrieve-~")
+    state_dict = state["keys"]
+    question = state_dict["question"]
+
+    if retriever is None:
+        return {"keys": {"documents": [], "question": question}}
+
+    documents = retriever.get_relevant_documents(question)
+    return {"keys": {"documents": documents, "question": question}}
+
+
+def generate(state):
+    """Generate an answer with Claude 3.5 Sonnet from the graded documents."""
+    print("~-generate-~")
+    state_dict = state["keys"]
+    question, documents = state_dict["question"], state_dict["documents"]
+    try:
+        prompt = PromptTemplate(template="""Based on the following context, please answer the question.
+        Context: {context}
+        Question: {question}
+        Answer:""", input_variables=["context", "question"])
+        llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", api_key=st.session_state.anthropic_api_key,
+                            temperature=0, max_tokens=1000)
+        context = "\n\n".join(doc.page_content for doc in documents)
+
+        # Create and run the chain
+        rag_chain = (
+            {"context": lambda x: context, "question": lambda x: question}
+            | prompt
+            | llm
+            | StrOutputParser()
+        )
+
+        generation = rag_chain.invoke({})
+
+        return {
+            "keys": {
+                "documents": documents,
+                "question": question,
+                "generation": generation
+            }
+        }
+
+    except Exception as e:
+        error_msg = f"Error in generate function: {str(e)}"
+        print(error_msg)
+        st.error(error_msg)
+        return {"keys": {"documents": documents, "question": question,
+                         "generation": "Sorry, I encountered an error while generating the response."}}
+
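+# Grading node: each retrieved document receives a yes/no relevance grade from
+# the LLM; any "no" sets the run_web_search flag so the graph rewrites the
+# query and falls back to Tavily.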
+def grade_documents(state):
+    """Determine whether the retrieved documents are relevant to the question."""
+    print("~-check relevance-~")
+    state_dict = state["keys"]
+    question = state_dict["question"]
+    documents = state_dict["documents"]
+
+    llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", api_key=st.session_state.anthropic_api_key,
+                        temperature=0, max_tokens=1000)
+
+    prompt = PromptTemplate(template="""You are grading the relevance of a retrieved document to a user question.
+    Return ONLY a JSON object with a "score" field that is either "yes" or "no".
+    Do not include any other text or explanation.
+
+    Document: {context}
+    Question: {question}
+
+    Rules:
+    - Check for related keywords or semantic meaning
+    - Use lenient grading to only filter clear mismatches
+    - Return exactly like this example: {{"score": "yes"}} or {{"score": "no"}}""",
+    input_variables=["context", "question"])
+
+    chain = (
+        prompt
+        | llm
+        | StrOutputParser()
+    )
+
+    filtered_docs = []
+    search = "No"
+
+    for d in documents:
+        try:
+            response = chain.invoke({"question": question, "context": d.page_content})
+            # Extract the JSON object in case the model adds surrounding text
+            json_match = re.search(r'\{.*\}', response)
+            if json_match:
+                response = json_match.group()
+
+            score = json.loads(response)
+
+            if score.get("score") == "yes":
+                print("~-grade: document relevant-~")
+                filtered_docs.append(d)
+            else:
+                print("~-grade: document not relevant-~")
+                search = "Yes"
+
+        except Exception as e:
+            print(f"Error grading document: {str(e)}")
+            # On error, keep the document to be safe
+            filtered_docs.append(d)
+            continue
+
+    return {"keys": {"documents": filtered_docs, "question": question, "run_web_search": search}}
+
+
+def transform_query(state):
+    """Transform the query to produce a better search question."""
+    print("~-transform query-~")
+    state_dict = state["keys"]
+    question = state_dict["question"]
+    documents = state_dict["documents"]
+
+    # Prompt for rewriting the question
+    prompt = PromptTemplate(
+        template="""Generate a search-optimized version of this question by
+        analyzing its core semantic meaning and intent.
+        \n ------- \n
+        {question}
+        \n ------- \n
+        Return only the improved question with no additional text:""",
+        input_variables=["question"],
+    )
+
+    llm = ChatAnthropic(
+        model="claude-3-5-sonnet-20241022",
+        api_key=st.session_state.anthropic_api_key,
+        temperature=0,
+        max_tokens=1000
+    )
+
+    chain = prompt | llm | StrOutputParser()
+    better_question = chain.invoke({"question": question})
+
+    return {
+        "keys": {"documents": documents, "question": better_question}
+    }
+
+
+def decide_to_generate(state):
+    """Route to web search when any document was graded irrelevant."""
+    print("~-decide to generate-~")
+    state_dict = state["keys"]
+    search = state_dict["run_web_search"]
+
+    if search == "Yes":
+        print("~-decision: transform query and run web search-~")
+        return "transform_query"
+    else:
+        print("~-decision: generate-~")
+        return "generate"
+
+def format_document(doc: Document) -> str:
+    """Render a short preview of a document for the step-by-step UI."""
+    return f"""
+    Source: {doc.metadata.get('source', 'Unknown')}
+    Title: {doc.metadata.get('title', 'No title')}
+    Content: {doc.page_content[:200]}...
+    """
+
+ """ + +def format_state(state: dict) -> str: + formatted = {} + + for key, value in state.items(): + if key == "documents": + formatted[key] = [format_document(doc) for doc in value] + else: + formatted[key] = value + + return formatted + + +workflow = StateGraph(GraphState) + +# Define the nodes by langgraph +workflow.add_node("retrieve", retrieve) +workflow.add_node("grade_documents", grade_documents) +workflow.add_node("generate", generate) +workflow.add_node("transform_query", transform_query) +workflow.add_node("web_search", web_search) + +# Build graph +workflow.set_entry_point("retrieve") +workflow.add_edge("retrieve", "grade_documents") +workflow.add_conditional_edges( + "grade_documents", + decide_to_generate, + { + "transform_query": "transform_query", + "generate": "generate", + }, +) +workflow.add_edge("transform_query", "web_search") +workflow.add_edge("web_search", "generate") +workflow.add_edge("generate", END) + +app = workflow.compile() + +st.title("Corrective RAG Demo") + +st.text("A possible query: What are the experiment results and ablation studies in this research paper?") + +# User input +user_question = st.text_input("Please enter your question:") + +if user_question: + inputs = { + "keys": { + "question": user_question, + } + } + + for output in app.stream(inputs): + for key, value in output.items(): + with st.expander(f"Step '{key}':"): + st.text(pprint.pformat(format_state(value["keys"]), indent=2, width=80)) + + final_generation = value['keys'].get('generation', 'No final generation produced.') + st.subheader("Final Generation:") + st.write(final_generation) diff --git a/rag_tutorials/corrective_rag/requirements.txt b/rag_tutorials/corrective_rag/requirements.txt new file mode 100644 index 00000000..ed8f0eb8 --- /dev/null +++ b/rag_tutorials/corrective_rag/requirements.txt @@ -0,0 +1,19 @@ +# Core dependencies +langchain==0.3.12 +langgraph==0.2.53 +qdrant-client==1.12.1 +langchain-openai==0.2.14 +langchain-anthropic==0.3.0 +tavily-python==0.5.0 +langchain-community==0.3.12 +langchain-core==0.3.28 +streamlit==1.41.1 +tenacity==8.5.0 + +anthropic>=0.7.0 +openai>=1.12.0 +tiktoken>=0.6.0 +pydantic>=2.0.0 +numpy>=1.24.0 +PyYAML>=6.0.0 +nest-asyncio>=1.5.0