-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
212 lines (167 loc) · 6.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
import chainlit as cl
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import CTransformers
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
# Prompt text for the retrieval QA chain. The {context} and {question}
# placeholders are the two input variables that set_custom_prompt() declares
# when it wraps this string in a PromptTemplate.
# NOTE(review): "The example of your response should be:" is not followed by
# an example in this text -- confirm whether one was intended.
prompt_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.
The example of your response should be:
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
def set_custom_prompt():
    """
    Build the prompt used for QA retrieval.

    Returns:
        PromptTemplate: The module-level ``prompt_template`` text wrapped in a
        PromptTemplate that declares ``context`` and ``question`` as its
        input variables.
    """
    expected_variables = ["context", "question"]
    return PromptTemplate(
        input_variables=expected_variables,
        template=prompt_template,
    )
def create_retrieval_qa_chain(llm, prompt, db, k=3):
    """
    Create a Retrieval Question-Answering (QA) chain.

    The chain is built with ``RetrievalQA.from_chain_type`` using the
    "stuff" chain type, and is configured to return the source documents
    alongside the answer.

    Args:
        llm (any): The language model to be used in the RetrievalQA.
        prompt (any): The prompt object handed to the chain through
            ``chain_type_kwargs`` (this file passes the PromptTemplate
            returned by ``set_custom_prompt``).
        db (any): The vector store; its ``as_retriever`` method supplies the
            retriever.
        k (int): Value passed to the retriever as ``search_kwargs["k"]``.
            Defaults to 3, the value that was previously hard-coded.

    Returns:
        RetrievalQA: The initialized QA chain.
    """
    retriever = db.as_retriever(search_kwargs={"k": k})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )
def load_model(
    model_path="model/llama-2-7b-chat.ggmlv3.q8_0.bin",
    model_type="llama",
    max_new_tokens=512,
    temperature=0.7,
):
    """
    Load a locally downloaded model through the CTransformers wrapper.

    Parameters:
        model_path (str): The path to the model to be loaded.
        model_type (str): The type of the model.
        max_new_tokens (int): The maximum number of new tokens for the model.
        temperature (float): The temperature parameter for the model.

    Returns:
        CTransformers: The loaded model.

    Raises:
        FileNotFoundError: If nothing exists at ``model_path``.

    Any error raised by CTransformers itself while loading (for example on a
    corrupt file) is not handled here and propagates to the caller.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No model file found at {model_path}")

    # Generation settings are handed over through ``config``, which is the
    # documented way to configure the LangChain CTransformers wrapper. The
    # previous version passed them as top-level keyword arguments behind
    # ``# type: ignore``.
    generation_config = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
    }
    return CTransformers(
        model=model_path,
        model_type=model_type,
        config=generation_config,
    )
def create_retrieval_qa_bot(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    persist_dir="./db",
    device="cpu",
):
    """
    Create a retrieval-based question-answering bot.

    Loads the embedding model, opens the Chroma store persisted in
    ``persist_dir``, loads the LLM with ``load_model()`` defaults, and wires
    everything into a RetrievalQA chain using the custom prompt.

    Parameters:
        model_name (str): The name of the model to be used for embeddings.
        persist_dir (str): The directory to persist the database.
        device (str): The device to run the model on (e.g., 'cpu', 'cuda').

    Returns:
        RetrievalQA: The retrieval-based question-answering bot.

    Raises:
        FileNotFoundError: If the persist directory does not exist.
        Exception: If loading the embeddings, loading the model, or creating
            the QA chain fails. The original error is attached as
            ``__cause__``.
    """
    if not os.path.exists(persist_dir):
        raise FileNotFoundError(f"No directory found at {persist_dir}")

    try:
        embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={"device": device},
        )
    except Exception as e:
        # ``from e`` keeps the underlying failure explicitly linked to the
        # wrapper exception in tracebacks.
        raise Exception(
            f"Failed to load embeddings with model name {model_name}: {str(e)}"
        ) from e

    db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

    try:
        llm = load_model()
    except Exception as e:
        raise Exception(f"Failed to load model: {str(e)}") from e

    qa_prompt = set_custom_prompt()

    try:
        qa = create_retrieval_qa_chain(llm=llm, prompt=qa_prompt, db=db)
    except Exception as e:
        raise Exception(f"Failed to create retrieval QA chain: {str(e)}") from e

    return qa
def retrieve_bot_answer(query):
    """
    Answer a single query with a freshly created QA bot.

    A new bot is built with ``create_retrieval_qa_bot()`` on every call and
    then invoked with the query.

    Args:
        query (str): The question to be answered by the QA bot.

    Returns:
        dict: Whatever the QA bot returns when called with
        ``{"query": query}``.
    """
    return create_retrieval_qa_bot()({"query": query})
@cl.on_chat_start
async def initialize_bot():
    """
    Set up the bot at the start of a chat.

    Builds the retrieval QA chain, sends a placeholder message, rewrites that
    same message into the welcome text, and finally stores the chain in the
    user session under the key ``"chain"``.
    """
    chain = create_retrieval_qa_bot()

    greeting = cl.Message(content="Starting the bot...")
    await greeting.send()

    # Reuse the message that was just sent: swap its text and push the update.
    final_text = "Hi, Welcome to Chat With Documents using Llama2 and LangChain."
    greeting.content = final_text
    await greeting.update()

    cl.user_session.set("chain", chain)
@cl.on_message
async def process_chat_message(message):
    """
    Process an incoming chat message.

    Fetches the QA chain stored in the user session under ``"chain"``, runs
    it asynchronously on the message text with a Chainlit LangChain callback
    handler attached, appends the source documents (or a "no sources" note)
    to the answer, and sends the result as a new message.

    Args:
        message: The incoming message. Either the text itself or an object
            exposing the text as ``.content``; both forms are accepted.
    """
    qa_chain = cl.user_session.get("chain")

    callback_handler = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): presumably forces streaming from the first token instead
    # of waiting for the prefix tokens above -- confirm against the Chainlit
    # callback handler.
    callback_handler.answer_reached = True

    # Depending on the Chainlit version the handler receives a plain string
    # or a message object; use ``.content`` when present, else the value.
    query = getattr(message, "content", message)
    response = await qa_chain.acall(query, callbacks=[callback_handler])

    bot_answer = response["result"]
    source_documents = response["source_documents"]

    if source_documents:
        bot_answer += "\nSources:" + str(source_documents)
    else:
        bot_answer += "\nNo sources found"

    await cl.Message(content=bot_answer).send()