-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
48 lines (39 loc) · 1.58 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
'''
===========================================
Module: Utility functions for building the retrieval-QA pipeline
===========================================
'''
import box
import yaml
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from src.prompts import qa_template
from src.llm import build_llm
# Load project configuration once at import time. The relative path is
# resolved against the current working directory. Box wraps the parsed YAML
# mapping so the functions below can read keys as attributes (cfg.VECTOR_COUNT).
with open('config/config.yml', encoding='utf8') as config_file:
    cfg = box.Box(yaml.safe_load(config_file))
def set_qa_prompt():
    """Build the prompt template used by the retrieval-QA chain.

    The template text is ``qa_template`` from ``src.prompts``. It is declared
    with two input variables, ``context`` and ``question``, which the chain
    fills in at query time.

    Returns:
        A ``PromptTemplate`` instance.
    """
    return PromptTemplate(
        template=qa_template,
        input_variables=['context', 'question'],
    )
def build_retrieval_qa(llm, prompt, vectordb):
    """Assemble a RetrievalQA chain on top of a vector store.

    Args:
        llm: Language model the chain uses to produce answers.
        prompt: Prompt template handed to the 'stuff' chain via
            ``chain_type_kwargs``.
        vectordb: Vector store to search. It is wrapped as a retriever with
            ``search_kwargs={'k': cfg.VECTOR_COUNT}``.

    Returns:
        The RetrievalQA chain. Whether it also returns the retrieved source
        documents is controlled by ``cfg.RETURN_SOURCE_DOCUMENTS``.
    """
    retriever = vectordb.as_retriever(search_kwargs={'k': cfg.VECTOR_COUNT})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=cfg.RETURN_SOURCE_DOCUMENTS,
        chain_type_kwargs={'prompt': prompt},
    )
def setup_dbqa(*, embedding_model="sentence-transformers/all-MiniLM-L6-v2",
               device='cpu'):
    """Load the FAISS index and build the end-to-end retrieval-QA chain.

    Steps:
      1. Create HuggingFace sentence embeddings.
      2. Load the FAISS vector store from ``cfg.DB_FAISS_PATH``.
      3. Build the LLM with ``build_llm()``.
      4. Build the QA prompt with ``set_qa_prompt()``.
      5. Wire the three together with ``build_retrieval_qa``.

    Args:
        embedding_model: HuggingFace model name used to embed queries. The
            default is the model previously hard-coded here. Presumably it
            must match the model the index at ``cfg.DB_FAISS_PATH`` was built
            with, or query vectors will not be comparable. Verify against the
            ingestion script.
        device: Device string passed to the embedding model through
            ``model_kwargs``. The default is ``'cpu'``, as before.

    Returns:
        The RetrievalQA chain produced by ``build_retrieval_qa``.
    """
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model,
                                       model_kwargs={'device': device})
    # NOTE(review): LangChain's FAISS.load_local deserializes files from this
    # directory, and the docstore is stored with pickle. Only point
    # DB_FAISS_PATH at data you trust.
    vectordb = FAISS.load_local(cfg.DB_FAISS_PATH, embeddings)
    llm = build_llm()
    qa_prompt = set_qa_prompt()
    return build_retrieval_qa(llm, qa_prompt, vectordb)