Skip to content

Commit

Permalink
Multiple updates
Browse files Browse the repository at this point in the history
  • Loading branch information
kennethleungty committed Jul 6, 2023
1 parent 3e694fb commit b95e89a
Show file tree
Hide file tree
Showing 9 changed files with 2,951 additions and 62 deletions.
10 changes: 10 additions & 0 deletions config/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
RETURN_SOURCE_DOCUMENTS: True
VECTOR_COUNT: 2
CHUNK_SIZE: 500
CHUNK_OVERLAP: 50
DATA_PATH: 'data/'
DB_FAISS_PATH: 'vectorstore/db_faiss'
MODEL_TYPE: 'mpt'
MODEL_BIN_PATH: 'models/mpt-7b-instruct.ggmlv3.q8_0.bin'
MAX_NEW_TOKENS: 256
TEMPERATURE: 0.1
34 changes: 10 additions & 24 deletions db_build.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
# =========================
# Vector DB Build
# Author: Kenneth Leung
# Module: Vector DB Build
# =========================
import box
import yaml
from langchain.vectorstores import Chroma, FAISS
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
# from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings

# Import config vars
with open('config/config.yml', 'r', encoding='utf8') as ymlfile:
cfg = box.Box(yaml.safe_load(ymlfile))

# See for more info: https://huggingface.co/hkunlp/instructor-xl
# EMBED_MODEL = 'hkunlp/instructor-large' # or 'hkunlp/instructor-xl'

# Build vector database
def build_db(vectorstore='FAISS'):
def run_db_build():
print('Start DB Build')
loader = DirectoryLoader(cfg.DATA_PATH,
glob="*.pdf",
loader_cls=PyPDFLoader)
Expand All @@ -27,28 +24,17 @@ def build_db(vectorstore='FAISS'):
text_splitter = RecursiveCharacterTextSplitter(chunk_size=cfg.CHUNK_SIZE,
chunk_overlap=cfg.CHUNK_OVERLAP)
texts = text_splitter.split_documents(documents)
# embedding = HuggingFaceInstructEmbeddings(model_name=EMBED_MODEL,
# model_kwargs={"device": 'cuda}
# )
# model_name = "sentence-transformers/all-mpnet-base-v2"

model_name = "sentence-transformers/all-MiniLM-L6-v2"
model_kwargs = {'device': 'cpu'}

embeddings = HuggingFaceEmbeddings(model_name=model_name,
model_kwargs=model_kwargs)
# Build specific DB
if vectorstore == 'Chroma':
vectordb = Chroma.from_documents(documents=texts,
embedding=embeddings,
persist_directory=cfg.DB_CHROMA_PATH)
vectordb.persist()
elif vectorstore == 'FAISS':
vectorstore = FAISS.from_documents(texts, embeddings)
vectorstore.save_local(cfg.DB_FAISS_PATH)
print('FAISS Vectorstore - Build Complete')
else:
raise ValueError('Error in DB selection')

vectorstore = FAISS.from_documents(texts, embeddings)
vectorstore.save_local(cfg.DB_FAISS_PATH)
print('FAISS Vectorstore - Build Complete')


if __name__ == "__main__":
build_db()
run_db_build()
46 changes: 17 additions & 29 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,9 @@
import box
import timeit
import yaml
import uvicorn
import argparse
from dotenv import find_dotenv, load_dotenv
from fastapi import FastAPI, BackgroundTasks, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from src.functions import setup_dbqa

app = FastAPI()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=['GET', 'POST', 'OPTIONS'],
allow_headers=['*'],
allow_credentials=True
)
from src.utils import setup_dbqa

# Load environment variables from .env file
load_dotenv(find_dotenv())
Expand All @@ -27,20 +13,22 @@
cfg = box.Box(yaml.safe_load(ymlfile))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('query',
type=str,
default='How much is the minimum guarantee payable by adidas?',
help='Enter the query to pass into the LLM')
args = parser.parse_args()
query = args.query

@app.get("/")
async def generate_llm_response(query: str):
# Load DBQA object
# Setup DBQA
start = timeit.default_timer()
dbqa = setup_dbqa()
response = dbqa({'query': 'How much is the minimum guarantee payable by adidas?'})
end = timeit.default_timer()
print(f"Time to load DBQA: {end - start}")

start = timeit.default_timer()
query = 'How many days of maternity leave do I have?'
response = dbqa({'query': query})
end = timeit.default_timer()
print(response)
print(f"Time to retrieve response: {end - start}")

return {"response": response}

print(response['result'])
print('='*50)
print(response['source_documents'])
print(f"Time to retrieve response: {end - start}")
Loading

0 comments on commit b95e89a

Please sign in to comment.