Showing 6 changed files with 1,953 additions and 0 deletions.
@@ -0,0 +1,135 @@
# Import the required libraries
import os
import sys
from LLM import InternLM_LLM
import requests
import gradio as gr
from BCEmbedding.tools.langchain import BCERerank
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.retrievers import ContextualCompressionRetriever
from openxlab.model import download

# Swap in pysqlite3 as the sqlite3 module (some platforms ship a system
# sqlite3 that is too old for downstream dependencies).
__import__('pysqlite3')
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
# download(model_repo='OpenLMLab/internlm2-chat-7b', output='internlm2-chat-7b')

os.makedirs('model', exist_ok=True)


def download_file_from_google_drive(url, destination):
    # Stream the download to disk in 1 KB chunks.
    response = requests.get(url, stream=True)
    with open(destination, "wb") as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)


# Download the BCE embedding and reranker models via the HF mirror.
# Note: the access token is hard-coded here; it would normally be injected
# through an environment variable or secret store rather than committed.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['HUGGINGFACE_TOKEN'] = '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'
os.system('huggingface-cli login --token $HUGGINGFACE_TOKEN')
os.system('huggingface-cli download --resume-download maidalun1020/bce-embedding-base_v1 --local-dir model/bce-embedding-base_v1')
os.system('huggingface-cli download --resume-download maidalun1020/bce-reranker-base_v1 --local-dir model/bce-reranker-base_v1')


def load_chain():
    # Build the QA chain: load the local FAISS index, wrap its retriever
    # with the BCE reranker, and attach the InternLM model.
    embedding_model_name = './model/bce-embedding-base_v1'  # 'maidalun1020/bce-embedding-base_v1'
    embedding_model_kwargs = {'device': 'cuda:0'}
    embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True, 'show_progress_bar': False}
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_name,
        model_kwargs=embedding_model_kwargs,
        encode_kwargs=embedding_encode_kwargs
    )
    loaded_index = FAISS.load_local('./faiss_index', embeddings)
    # Build the retriever: fetch up to 50 candidates from FAISS, then let
    # the BCE reranker re-score and compress them.
    reranker_args = {'model': './model/bce-reranker-base_v1', 'top_n': 50, 'device': 'cuda:0'}
    reranker = BCERerank(**reranker_args)
    retriever = loaded_index.as_retriever(search_type="similarity", search_kwargs={"score_threshold": 0.3, "k": 50})
    compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
    llm = InternLM_LLM(model_path="internlm2-chat-7b")

    # Prompt template (kept in Chinese, since the assistant is instructed to
    # answer in Chinese): use the given context to answer the question; if
    # the context is insufficient, say you do not know.
    template = """使用以下上下文来回答用户的问题。如果你不知道答案，就说你不知道。总是使用中文回答。
问题: {question}
可参考的上下文:
···
{context}
···
如果给定的上下文无法让你做出回答，请回答你不知道。
有用的回答:"""

    QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],
                                     template=template)

    # Assemble the RetrievalQA chain over the reranked retriever.
    qa_chain = RetrievalQA.from_chain_type(llm,
                                           retriever=compression_retriever,
                                           return_source_documents=True,
                                           chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})

    return qa_chain


class Model_center():
    """
    Holds the QA chain object.
    """
    def __init__(self):
        self.chain = load_chain()

    def qa_chain_self_answer(self, question: str, chat_history: list = []):
        """
        Answer a question with the history-free QA chain.
        """
        if question is None or len(question) < 1:
            return "", chat_history
        try:
            chat_history.append(
                (question, self.chain({"query": question})["result"]))
            return "", chat_history
        except Exception as e:
            return str(e), chat_history


model_center = Model_center()

block = gr.Blocks()
with block as demo:
    with gr.Row(equal_height=True):
        with gr.Column(scale=15):
            # Page title: "InternLM -- your personal sentencing assistant".
            gr.Markdown("""<h1><center>InternLM</center></h1>
                <center>你的专属量刑助手</center>
                """)
            # gr.Image(value=LOGO_PATH, scale=1, min_width=10, show_label=False, show_download_button=False)

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=450, show_copy_button=True)
            # Textbox for entering the prompt/question.
            msg = gr.Textbox(label="Prompt/问题")

            with gr.Row():
                # Submit button.
                db_wo_his_btn = gr.Button("Chat")
            with gr.Row():
                # Button that clears the chatbot component's contents.
                clear = gr.ClearButton(
                    components=[chatbot], value="Clear console")

        # On click, pass the user message and chat history to
        # qa_chain_self_answer, then update the textbox and the chatbot.
        db_wo_his_btn.click(model_center.qa_chain_self_answer, inputs=[
                            msg, chatbot], outputs=[msg, chatbot])

    # Footer notice (in Chinese): initializing the database may take a while;
    # any runtime exception will be shown in the text box.
    gr.Markdown("""提醒：<br>
    1. 初始化数据库时间可能较长，请耐心等待。
    2. 使用中如果出现异常，将会在文本输入框进行展示，请不要惊慌。 <br>
    """)
# Close any previously opened Gradio instances before launching.
gr.close_all()
# To share publicly on a port taken from the PORT1 environment variable:
# demo.launch(share=True, server_port=int(os.environ['PORT1']))
# Launch directly.
demo.launch()
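The InternLM_LLM class above comes from a local LLM module that this commit view does not show. For context, a minimal sketch of such a LangChain custom-LLM wrapper follows; only the class name and the model_path constructor argument are taken from the code above, while the base class, the transformers loading code, and the model.chat() call are assumptions about how the wrapper is likely written:

# Illustrative sketch only -- not part of this commit.
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer

class InternLM_LLM(LLM):
    tokenizer: object = None
    model: object = None

    def __init__(self, model_path: str):
        super().__init__()
        # Load the chat model from a local path such as "internlm2-chat-7b".
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True).cuda().eval()

    def _call(self, prompt: str, stop=None, run_manager=None, **kwargs) -> str:
        # InternLM2 chat checkpoints expose a chat() helper when loaded
        # with trust_remote_code=True.
        response, _history = self.model.chat(self.tokenizer, prompt, history=[])
        return response

    @property
    def _llm_type(self) -> str:
        return "InternLM"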
@@ -0,0 +1,52 @@
# First, import the required third-party libraries
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.text_splitter import MarkdownHeaderTextSplitter  # markdown splitter

# Load the Criminal Law corpus (刑法.md) as plain text.
tar_path = "./law_data/刑法.md"

with open(tar_path, 'r', encoding='utf-8') as file:
    loaded_text = file.read()

# Header levels to split on
headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    ("###", "Header 3"),
]
|
||
# 文档分割器 | ||
markdown_splitter = MarkdownHeaderTextSplitter( | ||
headers_to_split_on=headers_to_split_on | ||
) | ||
|
||
# 分割文档 | ||
split_docs = markdown_splitter.split_text(loaded_text) | ||
|
||
# print(split_docs) | ||
|
||
# 构建向量数据库 | ||
embedding_model_name = './model/bce-embedding-base_v1' # 'maidalun1020/bce-embedding-base_v1' | ||
embedding_model_kwargs = {'device': 'cuda:0'} | ||
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True, 'show_progress_bar': False} | ||
embeddings = HuggingFaceEmbeddings( | ||
model_name=embedding_model_name, | ||
model_kwargs=embedding_model_kwargs, | ||
encode_kwargs=embedding_encode_kwargs | ||
) | ||
|
||
faiss_index = FAISS.from_documents(split_docs, embeddings, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT) | ||
|
||
# 保存索引到磁盘 | ||
faiss_index.save_local('./faiss_index') | ||
|
||
# # 在将来需要的时候加载索引 | ||
# loaded_index = FAISS.read_index('path_to_saved_index') | ||
# retriever = loaded_index.as_retriever(search_type="similarity", search_kwargs={"score_threshold": 0.3, "k": 10}) | ||
|
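With the index written to ./faiss_index, a quick retrieval smoke test can confirm the store loads correctly before the Gradio app depends on it. A sketch, reusing the embedding settings above; the sample query is made up:

# Illustrative smoke test -- not part of this commit.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(
    model_name='./model/bce-embedding-base_v1',
    model_kwargs={'device': 'cuda:0'},
    encode_kwargs={'batch_size': 32, 'normalize_embeddings': True}
)
index = FAISS.load_local('./faiss_index', embeddings)
# Print header metadata and a snippet of the top three matching chunks.
for doc in index.similarity_search("盗窃罪如何量刑", k=3):
    print(doc.metadata, doc.page_content[:80])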
Binary file not shown.
Binary file not shown.