Commit

Update src files
kennethleungty committed Jul 6, 2023
1 parent 5e930b2 commit 3488725
Showing 7 changed files with 88 additions and 149 deletions.
Binary file added data/manu-20f-2022-09-24.pdf
54 changes: 54 additions & 0 deletions db_build.py
@@ -0,0 +1,54 @@
# =========================
# Vector DB Build
# Author: Kenneth Leung
# =========================
import box
import yaml
from langchain.vectorstores import Chroma, FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
# from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings

# Import config vars
with open('config/config.yml', 'r', encoding='utf8') as ymlfile:
    cfg = box.Box(yaml.safe_load(ymlfile))

# See for more info: https://huggingface.co/hkunlp/instructor-xl
# EMBED_MODEL = 'hkunlp/instructor-large' # or 'hkunlp/instructor-xl'

# Build vector database
def build_db(vectorstore='FAISS'):
    # Load every PDF in the data directory
    loader = DirectoryLoader(cfg.DATA_PATH,
                             glob="*.pdf",
                             loader_cls=PyPDFLoader)
    documents = loader.load()

    # Split documents into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=cfg.CHUNK_SIZE,
                                                   chunk_overlap=cfg.CHUNK_OVERLAP)
    texts = text_splitter.split_documents(documents)
    # embedding = HuggingFaceInstructEmbeddings(model_name=EMBED_MODEL,
    #                                           model_kwargs={"device": 'cuda'}
    #                                           )
    # model_name = "sentence-transformers/all-mpnet-base-v2"

    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    model_kwargs = {'device': 'cpu'}

    embeddings = HuggingFaceEmbeddings(model_name=model_name,
                                       model_kwargs=model_kwargs)
    # Build the selected vector store
    if vectorstore == 'Chroma':
        vectordb = Chroma.from_documents(documents=texts,
                                         embedding=embeddings,
                                         persist_directory=cfg.DB_CHROMA_PATH)
        vectordb.persist()
    elif vectorstore == 'FAISS':
        vectordb = FAISS.from_documents(texts, embeddings)
        vectordb.save_local(cfg.DB_FAISS_PATH)
        print('FAISS Vectorstore - Build Complete')
    else:
        raise ValueError('Error in DB selection')


if __name__ == "__main__":
    build_db()
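
Once the index has been built, it can be loaded back and queried on its own, which is a quick way to sanity-check the ingestion step. A minimal sketch, assuming the same config keys and embedding model used above; the query string is only illustrative:

# Sanity-check sketch: load the FAISS index built above and run a similarity search.
# Assumes config/config.yml defines DB_FAISS_PATH and that build_db() has already been run.
import box
import yaml
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

with open('config/config.yml', 'r', encoding='utf8') as ymlfile:
    cfg = box.Box(yaml.safe_load(ymlfile))

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                   model_kwargs={'device': 'cpu'})
vectordb = FAISS.load_local(cfg.DB_FAISS_PATH, embeddings)

# Retrieve the top 3 chunks most similar to an illustrative query
docs = vectordb.similarity_search("What was the total revenue for the financial year?", k=3)
for doc in docs:
    print(doc.metadata, doc.page_content[:200])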
15 changes: 3 additions & 12 deletions pyproject.toml
@@ -1,8 +1,8 @@
[tool.poetry]
name = "GPTer"
name = "Open-Source LLM - CPU Inference"
version = "0.1.0"
description = ""
authors = ["Kenneth Leung, Jack Yang"]
authors = ["Kenneth Leung]
license = ""
readme = "README.md"

@@ -11,27 +11,18 @@ start = "main:app"

[tool.poetry.dependencies]
python = "^3.9"
boto3 = "1.26.149"
faiss-cpu = "1.7.4"
Flask = "2.3.2"
langchain = "0.0.193"
openai = "0.27.6"
langchain = "0.0.225"
pypdf = "3.8.1"
python-dotenv = "1.0.0"
python-box = "7.0.1"
sentence-transformers = "2.2.2"
sagemaker = "2.163.0"
slack_bolt = "1.18.0"
slack_sdk = "3.21.3"
transformers = "4.29.0"
ipykernel = "^6.23.1"
ctransformers = "^0.2.5"
fastapi = "^0.96.0"
uvicorn = "^0.22.0"

[tool.poetry.group.dev.dependencies]
pylint = "^2.17.4"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
34 changes: 0 additions & 34 deletions src/functions.py

This file was deleted.

75 changes: 6 additions & 69 deletions src/llm.py
@@ -1,17 +1,11 @@
from langchain import SagemakerEndpoint
from langchain.llms import HuggingFacePipeline, GPT4All, CTransformers
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
'''
===========================================
Module: Open-source LLM Setup
===========================================
'''
from langchain.llms import CTransformers
from dotenv import find_dotenv, load_dotenv
from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline
from typing import Dict


import box
import json
import torch
import os
import yaml

# Load environment variables from .env file
@@ -30,61 +24,4 @@ def build_llm():
                                'temperature': cfg.TEMPERATURE}
                        )

# # HuggingFace Models
# tokenizer = LlamaTokenizer.from_pretrained("TheBloke/wizardLM-7B-HF")
# model = LlamaForCausalLM.from_pretrained("TheBloke/wizardLM-7B-HF",
# load_in_8bit=True,
# device_map='auto',
# # torch_dtype=torch.float16,
# low_cpu_mem_usage=True
# )
# pipe = pipeline(
# "text-generation",
# model=model,
# tokenizer=tokenizer,
# max_length=1024,
# temperature=0,
# top_p=0.95,
# repetition_penalty=1.15
# )
# llm = HuggingFacePipeline(pipeline=pipe)

# # OpenAI API
# llm = ChatOpenAI(
# model_name='gpt-3.5-turbo',
# temperature=0,
# openai_api_key=os.environ['OPENAI_API_KEY'],
# max_tokens=256
# )

# # Local GPT4ALL
# model_path = 'bin/ggml-gpt4all-j-v1.3-groovy.bin'
# # callbacks = [StreamingStdOutCallbackHandler()]
# llm = GPT4All(model=model_path,
# # backend='gptj',
# # callbacks=callbacks,
# verbose=True)

# # AWS SageMaker endpoint
# class ContentHandler(LLMContentHandler):
# content_type = "application/json"
# accepts = "application/json"

# def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
# input_str = json.dumps({prompt: prompt, **model_kwargs})
# return input_str.encode('utf-8')

# def transform_output(self, output: bytes) -> str:
# response_json = json.loads(output.read().decode("utf-8"))
# return response_json[0]["generated_text"]

# content_handler = ContentHandler()

# llm = SagemakerEndpoint(
# endpoint_name=cfg.AWS_SAGEMAKER_ENDPOINT_NAME,
# # credentials_profile_name="credentials-profile-name",
# region_name=cfg.AWS_REGION,
# model_kwargs={"temperature":1e-10},
# content_handler=content_handler)

return llm
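
The CTransformers wrapper returned by build_llm() is a standard LangChain LLM, so it can also be invoked on its own outside the retrieval chain. A rough sketch, assuming the quantized model binary referenced in config/config.yml is available locally; the prompt is only illustrative:

# Rough sketch: call the CPU-bound CTransformers LLM directly.
# Assumes the GGML model binary referenced in config/config.yml has been downloaded.
from src.llm import build_llm

llm = build_llm()

# LangChain LLMs accept a plain prompt string and return the generated text
response = llm("Explain in one sentence what a 20-F annual report is.")
print(response)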
26 changes: 3 additions & 23 deletions src/prompts.py
@@ -1,10 +1,11 @@
'''
===========================================
Module: Prompts collection
Module: Prompts collection
===========================================
'''
# Note that the spacing and indentation of the prompt template is important for MPT-7B-Instruct, as it is highly sensitive to these
# whitespace changes. For example, it could have problems generating a summay from the pieces of context
# whitespace changes. For example, it could have problems generating a summary from the pieces of context

mpt_7b_qa_template = """You are an expert HR assistant. Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -13,24 +14,3 @@
Question: {question}
Helpful detailed answer:"""


qa_system_template_prefix = """
You are an assistant to a human, powered by a large language model trained by OpenAI.
You are designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, you are able to generate human-like text based on the input you receive, allowing you to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
You are constantly learning and improving, and your capabilities are constantly evolving.
You are able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions.
You have access to some personalized information provided by the human in the Context section below.
"""


qa_system_template_main = """Use the following pieces of information to answer the human's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Helpful answer:"""
33 changes: 22 additions & 11 deletions src/utils.py
@@ -1,12 +1,18 @@
'''
===========================================
Module: Util functions
===========================================
'''
import box
import yaml
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate)

from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from src.prompts import mpt_7b_qa_template
from src.utils import set_qa_prompt, build_retrieval_qa
from src.llm import build_llm

# Import config vars
with open('config/config.yml', 'r', encoding='utf8') as ymlfile:
@@ -17,13 +23,6 @@ def set_qa_prompt():
    """
    Prompt template for QA retrieval for each vectorstore
    """
    # messages = [
    #     SystemMessagePromptTemplate.from_template(qa_system_template_prefix),
    #     SystemMessagePromptTemplate.from_template(qa_system_template_main),
    #     HumanMessagePromptTemplate.from_template('{question}')
    # ]
    # qa_prompt = ChatPromptTemplate.from_messages(messages)

    prompt = PromptTemplate(template=mpt_7b_qa_template,
                            input_variables=['context', 'question'])
    return prompt
@@ -37,3 +36,15 @@ def build_retrieval_qa(llm, prompt, vectordb):
                                       chain_type_kwargs={'prompt': prompt}
                                       )
    return dbqa


def setup_dbqa():
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                       model_kwargs={'device': 'cpu'})
    vectordb = FAISS.load_local(cfg.DB_FAISS_PATH, embeddings)

    llm = build_llm()
    qa_prompt = set_qa_prompt()
    dbqa = build_retrieval_qa(llm, qa_prompt, vectordb)

    return dbqa
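
Putting it together, setup_dbqa() yields a retrieval-QA chain that can be queried directly. A minimal sketch; the exact output keys depend on how build_retrieval_qa configures the RetrievalQA chain, and the question below is only illustrative:

# Minimal sketch: query the assembled retrieval-QA chain.
# Assumes db_build.py has been run so the FAISS index exists at cfg.DB_FAISS_PATH.
from src.utils import setup_dbqa

dbqa = setup_dbqa()
response = dbqa({'query': 'How many employees does the company have?'})
print(response['result'])

# If the chain returns source documents, inspect where each retrieved chunk came from
for doc in response.get('source_documents', []):
    print(doc.metadata)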
