From 6ca2a9a37d15dc2d785c3b8116147901ffcdf3cc Mon Sep 17 00:00:00 2001 From: eren23 Date: Fri, 7 Apr 2023 18:28:09 +0200 Subject: [PATCH] a flag to restart the context mid conversation --- knowledgegpt/__init__.py | 2 +- knowledgegpt/extractors/base_extractor.py | 5 +++-- knowledgegpt/utils/utils_completion.py | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/knowledgegpt/__init__.py b/knowledgegpt/__init__.py index ae4390f..a15fb29 100644 --- a/knowledgegpt/__init__.py +++ b/knowledgegpt/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.6b" +__version__ = "0.0.7b" from .extractors.yt_subs_extractor import YTSubsExtractor from .extractors.yt_audio_extractor import YoutubeAudioExtractor diff --git a/knowledgegpt/extractors/base_extractor.py b/knowledgegpt/extractors/base_extractor.py index e186796..c4b9ff0 100644 --- a/knowledgegpt/extractors/base_extractor.py +++ b/knowledgegpt/extractors/base_extractor.py @@ -63,7 +63,7 @@ def load_embeddings_indexes(self): if self.df is None: self.df = pd.read_csv(self.index_path + "/df.csv") - def extract(self, query, max_tokens, load_index=False) -> tuple[str, str, list]: + def extract(self, query, max_tokens, load_index=False, context_restarter=False) -> tuple[str, str, list]: """ param query: Query to answer param max_tokens: Maximum number of tokens to generate @@ -85,7 +85,8 @@ def extract(self, query, max_tokens, load_index=False) -> tuple[str, str, list]: messages=self.messages, max_tokens=max_tokens, index_type=self.index_type, - prompt_template=self.prompt_template + prompt_template=self.prompt_template, + context_restarter=context_restarter ) if not self.verbose: print("all_done!") diff --git a/knowledgegpt/utils/utils_completion.py b/knowledgegpt/utils/utils_completion.py index 2a301ef..f447be3 100644 --- a/knowledgegpt/utils/utils_completion.py +++ b/knowledgegpt/utils/utils_completion.py @@ -37,7 +37,8 @@ def answer_query_with_context( messages: list = None, index_type: str = "basic", max_tokens=1000, - prompt_template=None + prompt_template=None, + context_restarter: bool = False ) -> str: """ Answer a query using the provided context. @@ -55,7 +56,7 @@ def answer_query_with_context( """ - if len(messages) < 3 or not is_turbo: + if len(messages) < 3 or not is_turbo or context_restarter: prompt = construct_prompt( verbose=verbose, question=query,