forked from AllAboutAI-YT/easy-local-rag
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6ceabab
commit 8c7bca2
Showing
3 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import imaplib | ||
import email | ||
from email import policy | ||
from email.parser import BytesParser | ||
from datetime import datetime, timedelta | ||
import os | ||
import re | ||
import argparse | ||
from bs4 import BeautifulSoup | ||
import lxml | ||
from dotenv import load_dotenv | ||
|
||
load_dotenv() # Load environment variables from .env file | ||
|
||
def chunk_text(text, max_length=1000):
    """Clean raw email text and split it into chunks of at most
    ``max_length`` characters, breaking on sentence boundaries.

    Args:
        text: Raw text extracted from an email body.
        max_length: Soft upper bound on the size of each chunk.

    Returns:
        A list of cleaned text chunks (possibly empty).
    """
    # Normalize Unicode characters to the closest ASCII representation.
    text = text.encode('ascii', 'ignore').decode('ascii')

    # Remove sequences of '>' quote markers used in email reply threads.
    text = re.sub(r'\s*(?:>\s*){2,}', ' ', text)

    # Remove long runs of dashes/underscores used as visual separators.
    text = re.sub(r'-{3,}', ' ', text)
    text = re.sub(r'_{3,}', ' ', text)
    text = re.sub(r'\s{2,}', ' ', text)  # Collapse multiple spaces into one

    # Strip URLs entirely; they add noise to embeddings.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)

    # Normalize all remaining whitespace to single spaces.
    text = re.sub(r'\s+', ' ', text).strip()

    # Split into sentences while preserving terminal punctuation.
    sentences = re.split(r'(?<=[.!?]) +', text)
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        if len(current_chunk) + len(sentence) + 1 < max_length:
            # BUG FIX: the original did `(sentence + " ").strip()`, which
            # removed the separating space and glued sentences together.
            current_chunk += sentence + " "
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence + " "
    if current_chunk.strip():
        chunks.append(current_chunk.strip())

    return chunks
|
||
def save_chunks_to_vault(chunks):
    """Append each chunk as one stripped line to the shared vault file."""
    with open("vault.txt", "a", encoding="utf-8") as out:
        out.writelines(chunk.strip() + "\n" for chunk in chunks)
|
||
def get_text_from_html(html_content):
    """Strip markup from *html_content* and return only the visible text."""
    parsed = BeautifulSoup(html_content, 'lxml')
    return parsed.get_text()
|
||
def save_plain_text_content(email_bytes, email_id):
    """Extract the readable text from one raw RFC822 message, chunk it,
    and append the chunks to the vault.

    Args:
        email_bytes: The raw message as bytes (RFC822 form).
        email_id: Identifier of the message (currently informational only).

    Returns:
        The extracted plain text ("" when no text/plain or text/html part).
    """
    msg = BytesParser(policy=policy.default).parsebytes(email_bytes)

    def _decoded(part):
        # Decode a leaf part using its declared charset, defaulting to utf-8.
        return part.get_payload(decode=True).decode(part.get_content_charset('utf-8'))

    pieces = []
    # walk() visits every sub-part of a multipart message; a simple message
    # is treated as its own single part.
    for part in (msg.walk() if msg.is_multipart() else [msg]):
        content_type = part.get_content_type()
        if content_type == 'text/plain':
            pieces.append(_decoded(part))
        elif content_type == 'text/html':
            pieces.append(get_text_from_html(_decoded(part)))
    text_content = "".join(pieces)

    save_chunks_to_vault(chunk_text(text_content))
    return text_content
|
||
def search_and_process_emails(imap_client, email_source, search_keyword, start_date, end_date):
    """Search the selected mailbox of *imap_client* and archive every match.

    Args:
        imap_client: An authenticated imaplib client with a mailbox selected.
        email_source: Human-readable label used in log output (e.g. "Gmail").
        search_keyword: Optional keyword restricted to message bodies.
        start_date / end_date: Optional IMAP-format date bounds (DD-Mon-YYYY).
    """
    # Build the IMAP SEARCH criteria from whichever filters were supplied.
    if start_date and end_date:
        search_criteria = f'(SINCE "{start_date}" BEFORE "{end_date}")'
    else:
        search_criteria = 'ALL'
    if search_keyword:
        search_criteria += f' BODY "{search_keyword}"'  # Ensure the correct combination of conditions

    print(f"Using search criteria for {email_source}: {search_criteria}")
    status, data = imap_client.search(None, search_criteria)
    if status != 'OK':
        print(f"Failed to find emails with given criteria in {email_source}. No emails found.")
        return

    message_ids = data[0].split()
    print(f"Found {len(message_ids)} emails matching criteria in {email_source}.")

    for message_id in message_ids:
        status, payload = imap_client.fetch(message_id, '(RFC822)')
        decoded_id = message_id.decode('utf-8')
        if status == 'OK':
            print(f"Downloading and processing email ID: {decoded_id} from {email_source}")
            save_plain_text_content(payload[0][1], decoded_id)
        else:
            print(f"Failed to fetch email ID: {decoded_id} from {email_source}")
|
||
|
||
def main():
    """CLI entry point: parse arguments, connect to Gmail and Outlook over
    IMAP, and archive matching emails into the vault.

    Credentials are read from GMAIL_USERNAME/GMAIL_PASSWORD and
    OUTLOOK_USERNAME/OUTLOOK_PASSWORD environment variables (loaded from
    .env via python-dotenv at import time).
    """
    parser = argparse.ArgumentParser(description="Search and process emails based on optional keyword and date range.")
    parser.add_argument("--keyword", help="The keyword to search for in the email bodies.", default="")
    parser.add_argument("--startdate", help="Start date in DD.MM.YYYY format.", required=False)
    parser.add_argument("--enddate", help="End date in DD.MM.YYYY format.", required=False)
    args = parser.parse_args()

    start_date = None
    end_date = None

    # Both dates must be supplied together; convert DD.MM.YYYY into the
    # DD-Mon-YYYY form that IMAP SEARCH expects.
    if args.startdate and args.enddate:
        try:
            start_date = datetime.strptime(args.startdate, "%d.%m.%Y").strftime("%d-%b-%Y")
            end_date = datetime.strptime(args.enddate, "%d.%m.%Y").strftime("%d-%b-%Y")
        except ValueError as e:
            print(f"Error: Date format is incorrect. Please use DD.MM.YYYY format. Details: {e}")
            return
    elif args.startdate or args.enddate:
        print("Both start date and end date must be provided together.")
        return

    # Retrieve email credentials from environment variables.
    gmail_username = os.getenv('GMAIL_USERNAME')
    gmail_password = os.getenv('GMAIL_PASSWORD')
    outlook_username = os.getenv('OUTLOOK_USERNAME')
    outlook_password = os.getenv('OUTLOOK_PASSWORD')

    # FIX: guarantee logout even when a search/fetch raises mid-way; the
    # original leaked both IMAP sessions on any exception.
    gmail = imaplib.IMAP4_SSL('imap.gmail.com')
    try:
        gmail.login(gmail_username, gmail_password)
        gmail.select('inbox')

        outlook = imaplib.IMAP4_SSL('imap-mail.outlook.com')
        try:
            outlook.login(outlook_username, outlook_password)
            outlook.select('inbox')

            # Search and process emails from both providers.
            search_and_process_emails(gmail, "Gmail", args.keyword, start_date, end_date)
            search_and_process_emails(outlook, "Outlook", args.keyword, start_date, end_date)
        finally:
            outlook.logout()
    finally:
        gmail.logout()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
vault_file: "vault.txt" | ||
embeddings_file: "vault_embeddings.json" | ||
ollama_model: "llama3" | ||
top_k: 7 | ||
system_message: "You are a helpful assistant that is an expert at extracting the most useful information from a given text" | ||
|
||
ollama_api: | ||
base_url: "http://localhost:11434/v1" | ||
api_key: "llama3" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import torch | ||
import ollama | ||
import os | ||
import json | ||
from openai import OpenAI | ||
import argparse | ||
import yaml | ||
|
||
# ANSI escape codes for colors used in terminal output.
PINK = '\033[95m'        # defined but not referenced in this file's visible code
CYAN = '\033[96m'        # highlights retrieved document context
YELLOW = '\033[93m'      # colors the interactive user prompt
NEON_GREEN = '\033[92m'  # colors the model's response
RESET_COLOR = '\033[0m'  # restores the terminal's default color
|
||
def load_config(config_file):
    """Load and return the YAML configuration at *config_file*.

    Exits the process with status 1 when the file does not exist.
    """
    print("Loading configuration...")
    try:
        with open(config_file, 'r') as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        print(f"Configuration file '{config_file}' not found.")
        # FIX: raise SystemExit(1) instead of the site-provided `exit()`
        # helper, which is meant for interactive use and may be absent
        # under `python -S`; the observable effect (exit code 1) is the same.
        raise SystemExit(1)
|
||
def open_file(filepath):
    """Return the UTF-8 contents of *filepath*, or None when it does not exist."""
    print("Opening file...")
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            contents = handle.read()
    except FileNotFoundError:
        print(f"File '{filepath}' not found.")
        return None
    return contents
|
||
def load_or_generate_embeddings(vault_content, embeddings_file):
    """Return vault embeddings as a torch tensor.

    Loads the JSON cache at *embeddings_file* when present; otherwise (or
    when the cache is corrupt) generates embeddings from *vault_content*
    and writes them back to the cache.
    """
    if os.path.exists(embeddings_file):
        print(f"Loading embeddings from '{embeddings_file}'...")
        try:
            with open(embeddings_file, "r", encoding="utf-8") as file:
                return torch.tensor(json.load(file))
        except json.JSONDecodeError:
            print(f"Invalid JSON format in embeddings file '{embeddings_file}'.")
            # FIX: the original returned an empty tensor here, silently
            # disabling retrieval for the whole session; regenerate and
            # re-cache instead, same as the no-cache path below.
            embeddings = generate_embeddings(vault_content)
            save_embeddings(embeddings, embeddings_file)
    else:
        print("No embeddings found. Generating new embeddings...")
        embeddings = generate_embeddings(vault_content)
        save_embeddings(embeddings, embeddings_file)
    return torch.tensor(embeddings)
|
||
def generate_embeddings(vault_content):
    """Embed each vault entry with the mxbai-embed-large model via ollama.

    Entries that fail to embed are logged and skipped, so the returned
    list may be shorter than *vault_content*.
    """
    print("Generating embeddings...")
    result = []
    for entry in vault_content:
        try:
            reply = ollama.embeddings(model='mxbai-embed-large', prompt=entry)
            result.append(reply["embedding"])
        except Exception as e:
            print(f"Error generating embeddings: {str(e)}")
    return result
|
||
def save_embeddings(embeddings, embeddings_file):
    """Serialize *embeddings* as JSON into *embeddings_file*; errors are logged."""
    print(f"Saving embeddings to '{embeddings_file}'...")
    try:
        with open(embeddings_file, "w", encoding="utf-8") as handle:
            json.dump(embeddings, handle)
    except Exception as err:
        print(f"Error saving embeddings: {str(err)}")
|
||
def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k):
    """Return up to *top_k* vault entries most similar to the query,
    ranked by cosine similarity against the precomputed embeddings.

    Returns an empty list when the vault is empty or on any failure.
    """
    print("Retrieving relevant context...")
    if vault_embeddings.nelement() == 0:
        # Nothing indexed yet — no context to pull.
        return []
    try:
        query_vec = ollama.embeddings(model='mxbai-embed-large', prompt=rewritten_input)["embedding"]
        scores = torch.cosine_similarity(torch.tensor(query_vec).unsqueeze(0), vault_embeddings)
        k = min(top_k, len(scores))
        best_indices = torch.topk(scores, k=k)[1].tolist()
        return [vault_content[i].strip() for i in best_indices]
    except Exception as e:
        print(f"Error getting relevant context: {str(e)}")
        return []
|
||
def ollama_chat(user_input, system_message, vault_embeddings, vault_content, ollama_model, conversation_history, top_k, client):
    """Answer *user_input* with the Ollama model, prepending retrieved
    vault context and appending both turns to *conversation_history*.

    Returns the model's reply, or a fixed error string on failure.
    """
    relevant_context = get_relevant_context(user_input, vault_embeddings, vault_content, top_k)
    if relevant_context:
        context_str = "\n".join(relevant_context)
        print("Context Pulled from Documents: \n\n" + CYAN + context_str + RESET_COLOR)
        # Prepend the retrieved context to the user's question.
        prompt = context_str + "\n\n" + user_input
    else:
        print("No relevant context found.")
        prompt = user_input

    conversation_history.append({"role": "user", "content": prompt})
    messages = [{"role": "system", "content": system_message}, *conversation_history]

    try:
        response = client.chat.completions.create(
            model=ollama_model,
            messages=messages
        )
        answer = response.choices[0].message.content
        conversation_history.append({"role": "assistant", "content": answer})
        return answer
    except Exception as e:
        print(f"Error in Ollama chat: {str(e)}")
        return "An error occurred while processing your request."
|
||
def main():
    """CLI entry point: load config, prepare embeddings, then run an
    interactive question/answer loop against the vault."""
    arg_parser = argparse.ArgumentParser(description="Ollama Chat")
    arg_parser.add_argument("--config", default="config.yaml", help="Path to the configuration file")
    arg_parser.add_argument("--clear-cache", action="store_true", help="Clear the embeddings cache")
    arg_parser.add_argument("--model", help="Model to use for embeddings and responses")
    opts = arg_parser.parse_args()

    cfg = load_config(opts.config)

    # Drop any cached embeddings when the user asked for a fresh start.
    if opts.clear_cache and os.path.exists(cfg["embeddings_file"]):
        print(f"Clearing embeddings cache at '{cfg['embeddings_file']}'...")
        os.remove(cfg["embeddings_file"])

    # A --model flag overrides the model named in the config file.
    if opts.model:
        cfg["ollama_model"] = opts.model

    vault_lines = []
    if os.path.exists(cfg["vault_file"]):
        print(f"Loading content from vault '{cfg['vault_file']}'...")
        with open(cfg["vault_file"], "r", encoding='utf-8') as vault_file:
            vault_lines = vault_file.readlines()

    embeddings_tensor = load_or_generate_embeddings(vault_lines, cfg["embeddings_file"])

    # Ollama exposes an OpenAI-compatible API, so the OpenAI client is
    # pointed at the local Ollama endpoint from the config file.
    chat_client = OpenAI(
        base_url=cfg["ollama_api"]["base_url"],
        api_key=cfg["ollama_api"]["api_key"]
    )

    history = []
    system_message = cfg["system_message"]

    while True:
        user_input = input(YELLOW + "Ask a question about your documents (or type 'quit' to exit): " + RESET_COLOR)
        if user_input.lower() == 'quit':
            break
        reply = ollama_chat(user_input, system_message, embeddings_tensor, vault_lines, cfg["ollama_model"], history, cfg["top_k"], chat_client)
        print(NEON_GREEN + "Response: \n\n" + reply + RESET_COLOR)


if __name__ == "__main__":
    main()