forked from AllAboutAI-YT/easy-local-rag
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlocalrag_no_rewrite.py
110 lines (92 loc) · 4.36 KB
/
localrag_no_rewrite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import ollama
import os
from openai import OpenAI
import argparse
# Terminal colour codes: each value is an ANSI SGR escape sequence
# (ESC, written here as \x1b, followed by "[<code>m").
# RESET_COLOR switches the terminal back to its default attributes and is
# printed after every coloured span.
PINK = "\x1b[95m"
CYAN = "\x1b[96m"
YELLOW = "\x1b[93m"
NEON_GREEN = "\x1b[92m"
RESET_COLOR = "\x1b[0m"
# Function to open a file and return its contents as a string
def open_file(filepath):
    """Read the entire file at ``filepath`` as UTF-8 text and return it.

    The file is opened in text mode (the default), so line endings go
    through Python's universal-newlines translation.
    """
    with open(filepath, encoding='utf-8') as handle:
        contents = handle.read()
    return contents
# Function to get relevant context from the vault based on user input
def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3,
                         embedding_model='mxbai-embed-large'):
    """Return the vault lines most similar to ``rewritten_input``.

    The query is embedded with ``ollama.embeddings`` and compared against
    every row of ``vault_embeddings`` by cosine similarity. The entries of
    ``vault_content`` behind the ``top_k`` highest-scoring rows are returned
    in descending score order, with surrounding whitespace stripped.

    Args:
        rewritten_input: Query text to embed.
        vault_embeddings: Tensor of vault embeddings, one row per entry of
            ``vault_content`` (row ``i`` must correspond to
            ``vault_content[i]``).
        vault_content: The vault's text lines.
        top_k: Maximum number of lines to return. It is capped at the number
            of vault rows; a non-positive value returns ``[]``.
        embedding_model: Ollama model used to embed the query. For the
            scores to be meaningful it should be the same model that
            produced ``vault_embeddings``.

    Returns:
        A list of at most ``top_k`` stripped strings. The list is empty when
        the vault has no embeddings or ``top_k`` is not positive.
    """
    # Nothing to search, or nothing requested: skip the network call entirely.
    if vault_embeddings.nelement() == 0 or top_k <= 0:
        return []

    input_embedding = ollama.embeddings(model=embedding_model, prompt=rewritten_input)["embedding"]
    # Build the query on the vault's dtype/device so the similarity op sees
    # matching tensors; unsqueeze(0) makes it a single row that broadcasts
    # against every vault row.
    query = torch.tensor(
        input_embedding,
        dtype=vault_embeddings.dtype,
        device=vault_embeddings.device,
    ).unsqueeze(0)
    cos_scores = torch.cosine_similarity(query, vault_embeddings)

    # torch.topk raises if k exceeds the number of scores, so cap it.
    k = min(top_k, len(cos_scores))
    top_indices = torch.topk(cos_scores, k=k).indices.tolist()
    return [vault_content[idx].strip() for idx in top_indices]
# Function to interact with the Ollama model
def ollama_chat(user_input, system_message, vault_embeddings, vault_content, ollama_model, conversation_history):
    """Answer ``user_input`` with the chat model, grounded in retrieved vault context.

    Up to three vault lines similar to the question are retrieved and
    prepended to the user's message. That message is appended to
    ``conversation_history``. The system message plus the full history is
    then sent through the module-level ``client``, and the assistant's reply
    is appended to the history as well.

    Args:
        user_input: The user's question.
        system_message: System prompt placed before the history on every call.
        vault_embeddings: Embedding tensor for the vault, aligned row-for-row
            with ``vault_content``; used for retrieval.
        vault_content: The vault's text lines.
        ollama_model: Model name passed to ``client.chat.completions.create``.
        conversation_history: List of ``{"role": ..., "content": ...}`` dicts.
            It is mutated in place: one user and one assistant entry are
            appended per call.

    Returns:
        The assistant's reply text.
    """
    # Retrieve with the embeddings the caller supplied, not a module global.
    relevant_context = get_relevant_context(user_input, vault_embeddings, vault_content, top_k=3)

    user_input_with_context = user_input
    if relevant_context:
        # One retrieved line per row, then the question after a blank line.
        context_str = "\n".join(relevant_context)
        print("Context Pulled from Documents: \n\n" + CYAN + context_str + RESET_COLOR)
        user_input_with_context = context_str + "\n\n" + user_input
    else:
        print(CYAN + "No relevant context found." + RESET_COLOR)

    conversation_history.append({"role": "user", "content": user_input_with_context})
    # The system prompt is not stored in the history; it is re-sent first each turn.
    messages = [{"role": "system", "content": system_message}, *conversation_history]

    response = client.chat.completions.create(model=ollama_model, messages=messages)
    reply = response.choices[0].message.content
    conversation_history.append({"role": "assistant", "content": reply})
    return reply
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Ollama Chat")
# %(default)s is filled in by argparse, so the help text always shows the
# real default (it previously claimed "llama3" while the default was
# "dolphin-llama3").
parser.add_argument(
    "--model",
    default="dolphin-llama3",
    help="Ollama model to use (default: %(default)s)",
)
args = parser.parse_args()

# Configuration for the Ollama API client: the OpenAI client is pointed at
# a local server on port 11434 under the /v1 path.
# NOTE(review): the api_key value looks like a placeholder the OpenAI client
# requires; confirm the local server does not actually check it.
client = OpenAI(
    base_url='http://localhost:11434/v1',
    api_key='dolphin-llama3'
)
# Load the vault: every line of vault.txt (kept with its trailing newline)
# is one retrievable passage. A missing file simply yields an empty vault.
if os.path.exists("vault.txt"):
    with open("vault.txt", "r", encoding='utf-8') as vault_file:
        vault_content = vault_file.readlines()
else:
    vault_content = []

# Embed each vault line via Ollama; list position i holds the embedding
# of vault_content[i], which retrieval relies on.
vault_embeddings = [
    ollama.embeddings(model='mxbai-embed-large', prompt=line)["embedding"]
    for line in vault_content
]

# Stack the per-line embeddings into one tensor (one row per line) and
# echo it to the console.
vault_embeddings_tensor = torch.tensor(vault_embeddings)
print("Embeddings for each line in the vault:")
print(vault_embeddings_tensor)
# Conversation loop: keep answering questions until the user types 'quit'
# or the input stream ends.
conversation_history = []
system_message = "You are a helpful assistant that is an expert at extracting the most useful information from a given text"
while True:
    try:
        user_input = input(YELLOW + "Ask a question about your documents (or type 'quit' to exit): " + RESET_COLOR)
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt: end the session cleanly instead of
        # exiting with a traceback. print() moves off the prompt line.
        print()
        break
    # Tolerate surrounding whitespace in the exit command.
    if user_input.strip().lower() == 'quit':
        break
    if not user_input.strip():
        # Don't send an empty question to the embedding/chat models.
        continue
    response = ollama_chat(user_input, system_message, vault_embeddings_tensor, vault_content, args.model, conversation_history)
    print(NEON_GREEN + "Response: \n\n" + response + RESET_COLOR)