deprecated topK; add json_format dict; add use_rerank
tea9297 committed Feb 26, 2024
1 parent f6125f1 commit 614e751
Showing 10 changed files with 274 additions and 153 deletions.
7 changes: 6 additions & 1 deletion .gitignore
@@ -59,4 +59,9 @@ dataset.json
expert.json
question_set2.txt
result.csv
test_*.py
test_*.py
.vscode/settings.json
SIMYOU.TTF
SIMYOU.cw127.pkl
SIMYOU.pkl
default.pdf
47 changes: 39 additions & 8 deletions akasha/akasha.py
@@ -11,6 +11,7 @@
import akasha.prompts as prompts
import akasha.db
import datetime, traceback
import warnings
from dotenv import load_dotenv

load_dotenv(pathlib.Path().cwd() / ".env")
@@ -180,7 +181,7 @@ def __init__(
chunk_size: int = 1000,
model: str = "openai:gpt-3.5-turbo",
verbose: bool = False,
topK: int = 2,
topK: int = -1,
threshold: float = 0.2,
language: str = "ch",
search_type: Union[str, Callable] = "svm",
@@ -225,6 +226,10 @@ def __init__(
self.temperature = temperature

self.timestamp_list = []
if topK != -1:
warnings.warn(
"The 'topK' parameter is deprecated and will be removed in future versions",
DeprecationWarning)

def _set_model(self, **kwargs):
"""change model, embeddings, search_type, temperature if user use **kwargs to change them."""
@@ -439,18 +444,42 @@ def __init__(
chunk_size: int = 1000,
model: str = "openai:gpt-3.5-turbo",
verbose: bool = False,
topK: int = 2,
topK: int = -1,
threshold: float = 0.2,
language: str = "ch",
search_type: Union[str, Callable] = "svm",
record_exp: str = "",
system_prompt: str = "",
prompt_format_type: str = "gpt",
max_doc_len: int = 1500,
temperature: float = 0.0,
compression: bool = False,
use_chroma: bool = False,
use_rerank: bool = False,
ignore_check: bool = False,
):
"""initials of Doc_QA class
Args:
embeddings (_type_, optional): embedding model, including two types(openai and huggingface). Defaults to "openai:text-embedding-ada-002".
chunk_size (int, optional): the max length of each text segments. Defaults to 1000.
model (_type_, optional): language model. Defaults to "openai:gpt-3.5-turbo".
verbose (bool, optional): print the processing text or not. Defaults to False.
topK (int, optional): the number of documents to be selected. Defaults to 2.
threshold (float, optional): threshold of similarity for searching relavant documents. Defaults to 0.2.
language (str, optional): "ch" chinese or "en" english. Defaults to "ch".
search_type (Union[str, Callable], optional): _description_. Defaults to "svm".
record_exp (str, optional): experiment name of aiido. Defaults to "".
system_prompt (str, optional): the prompt you want llm to output in certain format. Defaults to "".
prompt_format_type (str, optional): the prompt and system prompt format for the language model, including two types(gpt and llama). Defaults to "gpt".
max_doc_len (int, optional): max total length of selected documents. Defaults to 1500.
temperature (float, optional): temperature for language model. Defaults to 0.0.
compression (bool, optional): compress the selected documents or not. Defaults to False.
use_chroma (bool, optional): use chroma db name instead of documents path to load data or not. Defaults to False.
use_rerank (bool, optional): use rerank model to re-rank the selected documents or not. Defaults to False.
ignore_check (bool, optional): speed up loading data if the chroma db is already existed. Defaults to False.
"""

super().__init__(
chunk_size,
model,
@@ -469,6 +498,8 @@ def __init__(
self.compression = compression
self.use_chroma = use_chroma
self.ignore_check = ignore_check
self.use_rerank = use_rerank
self.prompt_format_type = prompt_format_type
### set variables ###
self.logs = {}
self.model_obj = helper.handle_model(model, self.verbose,
@@ -530,7 +561,7 @@ def get_response(self, doc_path: Union[List[str], str], prompt: str,
self.db,
self.embeddings_obj,
self.prompt,
self.topK,
self.use_rerank,
self.threshold,
self.language,
self.search_type,
@@ -549,7 +580,7 @@ def get_response(self, doc_path: Union[List[str], str], prompt: str,
if self.system_prompt.replace(' ', '') == "":
self.system_prompt = prompts.default_doc_ask_prompt()
prod_sys_prompt, prod_prompt = prompts.format_sys_prompt(
self.system_prompt, self.prompt)
self.system_prompt, self.prompt, self.prompt_format_type)

self.response = self._ask_model(prod_sys_prompt, prod_prompt)

@@ -640,7 +671,7 @@ def recursive_get_response(prompt_list):
self.db,
self.embeddings_obj,
prompt,
self.topK,
self.use_rerank,
self.threshold,
self.language,
self.search_type,
@@ -656,7 +687,7 @@ def recursive_get_response(prompt_list):
self.docs = docs + pre_result
## format prompt ##
prod_sys_prompt, prod_prompt = prompts.format_sys_prompt(
self.system_prompt, prompt)
self.system_prompt, prompt, self.prompt_format_type)

response = self._ask_model(prod_sys_prompt, prod_prompt)

@@ -737,7 +768,7 @@ def ask_whole_file(self, file_path: str, prompt: str, **kwargs) -> str:
if self.system_prompt.replace(' ', '') == "":
self.system_prompt = prompts.default_doc_ask_prompt()
prod_sys_prompt, prod_prompt = prompts.format_sys_prompt(
self.system_prompt, self.prompt)
self.system_prompt, self.prompt, self.prompt_format_type)
self.response = self._ask_model(prod_sys_prompt, prod_prompt)

end_time = time.time()
@@ -804,7 +835,7 @@ def ask_self(self,
if self.system_prompt.replace(' ', '') == "":
self.system_prompt = prompts.default_doc_ask_prompt()
prod_sys_prompt, prod_prompt = prompts.format_sys_prompt(
self.system_prompt, self.prompt)
self.system_prompt, self.prompt, self.prompt_format_type)
self.response = self._ask_model(prod_sys_prompt, prod_prompt)

end_time = time.time()
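Taken together, the akasha.py changes mean callers should stop passing `topK` (any value other than the new default of -1 triggers a `DeprecationWarning`; `max_doc_len` already bounds how much selected text is used) and can opt into the new `use_rerank` and `prompt_format_type` parameters. Below is a minimal usage sketch, not part of this commit: it assumes `Doc_QA` is exported at the package level and that the constructor forwards `topK` to the base class, as the hunks above suggest; the document path and prompt are placeholders.

```python
import warnings

import akasha  # assumes the package layout shown in this diff

# New-style construction: no topK; use_rerank and prompt_format_type are new.
qa = akasha.Doc_QA(
    model="openai:gpt-3.5-turbo",
    prompt_format_type="gpt",  # "gpt" or "llama" prompt/system-prompt format
    use_rerank=True,           # re-rank the retrieved documents
    max_doc_len=1500,          # max total length of selected documents
)
response = qa.get_response("docs/", "What does the report conclude?")  # placeholder args

# topK still exists as a parameter, but any non-default value now emits
# a DeprecationWarning from __init__.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    akasha.Doc_QA(topK=2)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

Note that `prompt_format_type` is stored on the instance and passed as the new third argument to `prompts.format_sys_prompt` in every ask path above, so one constructor argument now controls both system-prompt and prompt formatting.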
36 changes: 17 additions & 19 deletions akasha/eval/eval.py
@@ -85,17 +85,19 @@ def __init__(
chunk_size: int = 1000,
model: str = "openai:gpt-3.5-turbo",
verbose: bool = False,
topK: int = 2,
topK: int = -1,
threshold: float = 0.2,
language: str = "ch",
search_type: Union[str, Callable] = "svm",
record_exp: str = "",
system_prompt: str = "",
prompt_format_type: str = "gpt",
max_doc_len: int = 1500,
temperature: float = 0.0,
question_type: str = "fact",
question_style: str = "essay",
use_chroma: bool = False,
use_rerank: bool = False,
ignore_check: bool = False,
):
"""initials of Model_Eval class
@@ -121,6 +123,7 @@ def __init__(
**temperature (float, optional)**: temperature of llm model from 0.0 to 1.0. Defaults to 0.0.\n
**question_style (str, optional)**: the style of question you want to generate, "essay" or "single_choice". Defaults to "essay".\n
**question_type (str, optional)**: the type of question you want to generate, "fact", "summary", "irrelevant", "compared". Defaults to "fact".\n
**use_rerank (bool, optional)**: use rerank model to re-rank the selected documents or not. Defaults to False.
"""

super().__init__(
@@ -141,7 +144,7 @@ def __init__(
self.question_type = question_type
self.question_style = question_style
self.question_num = 0

self.prompt_format_type = prompt_format_type
### set variables ###
self.logs = {}
self.model_obj = akasha.helper.handle_model(model, self.verbose,
Expand All @@ -161,6 +164,7 @@ def __init__(
self.score = {}
self.use_chroma = use_chroma
self.ignore_check = ignore_check
self.use_rerank = use_rerank

def _save_questionset(self, timestamp: str, output_file_path: str):
"""save questions and ref answers into txt file, and save the path of question set into logs
@@ -486,7 +490,7 @@ def _eval_get_res_fact(self, question: Union[str, list], answer: str,
prod_sys = self.system_prompt + akasha.prompts.default_doc_ask_prompt(
)
prod_sys, query_with_prompt = akasha.prompts.format_sys_prompt(
prod_sys, question)
prod_sys, question, self.prompt_format_type)
else:
prod_sys = self.system_prompt
query, ans = akasha.prompts.format_question_query(question, answer)
@@ -497,7 +501,7 @@ def _eval_get_res_fact(self, question: Union[str, list], answer: str,
self.db,
self.embeddings_obj,
query,
self.topK,
self.use_rerank,
self.threshold,
self.language,
self.search_type,
@@ -571,7 +575,7 @@ def _eval_get_res_summary(self, sum_doc: str, answer: str,

prompt = "請對以上文件進行摘要。"
prod_sys, query_with_prompt = akasha.prompts.format_sys_prompt(
self.system_prompt, prompt)
self.system_prompt, prompt, self.prompt_format_type)

self.docs = [
Document(page_content=sum_doc, metadata={
@@ -645,7 +649,7 @@ def auto_create_questionset(
choice_num: int = 4,
output_file_path: str = "",
**kwargs,
) -> (list, list):
) -> Tuple[list, list]:
"""auto create question set by llm model, each time it will randomly select a range of documents from the documents directory,
then use llm model to generate a question and answer pair, and save it into a txt file.
1.The format of "single_choice" questionset should be one line one question, and the possibles answers and questions are separate by tab(\t),
@@ -953,10 +957,9 @@ def optimum_combination(
embeddings_list: list = ["openai:text-embedding-ada-002"],
chunk_size_list: list = [500],
model_list: list = ["openai:gpt-3.5-turbo"],
topK_list: list = [2],
search_type_list: list = ["svm", "tfidf", "mmr"],
**kwargs,
) -> (list, list):
) -> Tuple[list, list]:
"""test all combinations of giving lists, and run auto_evaluation to find parameters of the best result.
Args:
@@ -966,7 +969,6 @@ def optimum_combination(
**embeddings_list (_type_, optional)**: list of embeddings models. Defaults to ["openai:text-embedding-ada-002"].\n
**chunk_size_list (list, optional)**: list of chunk sizes. Defaults to [500].\n
**model_list (_type_, optional)**: list of models. Defaults to ["openai:gpt-3.5-turbo"].\n
**topK_list (list, optional)**: list of topK. Defaults to [2].\n
**threshold (float, optional)**: the similarity threshold of searching. Defaults to 0.2.\n
**search_type_list (list, optional)**: list of search types, currently have "merge", "svm", "knn", "tfidf", "mmr". Defaults to ['svm','tfidf','mmr'].
Returns:
@@ -977,7 +979,7 @@ def optimum_combination(
start_time = time.time()
combinations = akasha.helper.get_all_combine(embeddings_list,
chunk_size_list,
model_list, topK_list,
model_list,
search_type_list)
progress = tqdm(len(combinations),
total=len(combinations),
@@ -991,7 +993,7 @@ def optimum_combination(
else:
bcr = 0.0

for embed, chk, mod, tK, st in combinations:
for embed, chk, mod, st in combinations:
progress.update(1)

if self.question_type.lower() == "essay":
@@ -1001,7 +1003,6 @@ def optimum_combination(
embeddings=embed,
chunk_size=chk,
model=mod,
topK=tK,
search_type=st,
)

@@ -1015,7 +1016,6 @@ def optimum_combination(
embed,
chk,
mod,
tK,
self.search_type_str,
)
else:
@@ -1025,7 +1025,6 @@ def optimum_combination(
embeddings=embed,
chunk_size=chk,
model=mod,
topK=tK,
search_type=st,
)
bcr = max(bcr, cur_correct_rate)
@@ -1035,7 +1034,6 @@ def optimum_combination(
embed,
chk,
mod,
tK,
self.search_type_str,
)
result_list.append(cur_tup)
@@ -1093,7 +1091,7 @@ def create_topic_questionset(
choice_num: int = 4,
output_file_path: str = "",
**kwargs,
) -> (list, list):
) -> Tuple[list, list]:
"""similar to auto_create_questionset, but it will use the topic to find the related documents and create questionset.
Args:
**doc_path (str)**: documents directory path\n
@@ -1157,19 +1155,19 @@ def create_topic_questionset(
self.db,
self.embeddings_obj,
topic,
99,
self.use_rerank,
self.threshold,
self.language,
self.search_type,
self.verbose,
self.model_obj,
999999,
25000,
self.logs[timestamp],
)

texts = [doc.page_content for doc in self.docs]
metadata = [doc.metadata for doc in self.docs]
print(texts)

doc_range = min(doc_range, len(texts))

progress = tqdm(total=question_num,
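On the eval side, the `topK_list` knob disappears from `optimum_combination`: `get_all_combine` now iterates over (embeddings, chunk_size, model, search_type) tuples only, and `Model_Eval` gains the same `use_rerank` and `prompt_format_type` parameters. A sketch of the updated call follows; it is not from this commit and assumes `Model_Eval` is importable from `akasha.eval` and that the two leading positional arguments (question-set file, then documents directory) match earlier releases. Both paths are placeholders.

```python
from akasha.eval import Model_Eval  # assumed import path

ev = Model_Eval(
    question_style="essay",  # "essay" or "single_choice"
    question_type="fact",    # "fact", "summary", "irrelevant", "compared"
    use_rerank=True,         # new: re-rank the retrieved documents
)

# topK_list is gone; only these four lists are combined now.
results = ev.optimum_combination(
    "questionset/docs_essay.txt",  # placeholder question-set file
    "docs/",                       # placeholder documents directory
    embeddings_list=["openai:text-embedding-ada-002"],
    chunk_size_list=[500, 1000],
    model_list=["openai:gpt-3.5-turbo"],
    search_type_list=["svm", "tfidf", "mmr"],
)  # returns Tuple[list, list] per the new annotation
```

The result tuples recorded per combination accordingly drop the `tK` field, so downstream code that unpacked (embeddings, chunk_size, model, topK, search_type) needs the four-element form.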
