From d7ed470b9b75927a45c10447d7696aecf854bde7 Mon Sep 17 00:00:00 2001 From: Daedalus Date: Tue, 26 Dec 2023 18:44:55 +0100 Subject: [PATCH] Add multiple node postprocessors Multiple node postprocessors can now be used instead of just one. --- gptstonks_api/main.py | 18 ++++++++++++------ gptstonks_api/utils.py | 11 ++++++----- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/gptstonks_api/main.py b/gptstonks_api/main.py index 1370098..aa7d5ec 100644 --- a/gptstonks_api/main.py +++ b/gptstonks_api/main.py @@ -19,7 +19,10 @@ from llama_index.embeddings import OpenAIEmbedding from llama_index.embeddings.openai import OpenAIEmbeddingModelType from llama_index.llms import LangChainLLM -from llama_index.postprocessor import MetadataReplacementPostProcessor +from llama_index.postprocessor import ( + MetadataReplacementPostProcessor, + SimilarityPostprocessor, +) from openbb import obb from openbb_chat.kernels.auto_llama_index import AutoLlamaIndex @@ -158,10 +161,13 @@ def init_data(): app.AI_PREFIX = "GPTSTONKS_RESPONSE" app.python_repl_utility = PythonREPL() app.python_repl_utility.globals = globals() - app.node_postprocessor = ( - MetadataReplacementPostProcessor(target_metadata_key="extra_context") + app.node_postprocessors = ( + [ + SimilarityPostprocessor(similarity_cutoff=0.8), + MetadataReplacementPostProcessor(target_metadata_key="extra_context"), + ] if os.getenv("REMOVE_POSTPROCESSOR", None) is None - else None + else [SimilarityPostprocessor(similarity_cutoff=0.8)] ) search_tool = DuckDuckGoSearchResults(api_wrapper=DuckDuckGoSearchAPIWrapper()) wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()) @@ -195,7 +201,7 @@ def init_data(): get_openbb_chat_output_executed, auto_llama_index=app.auto_llama_index, python_repl_utility=app.python_repl_utility, - node_postprocessor=app.node_postprocessor, + node_postprocessors=app.node_postprocessors, ), description=os.getenv("OPENBBCHAT_TOOL_DESCRIPTION"), return_direct=True, @@ -286,7 +292,7 @@ async def run_model_in_background(query: str, use_agent: bool, openbb_pat: str | openbbchat_output = await get_openbb_chat_output( query_str=query, auto_llama_index=app.auto_llama_index, - node_postprocessor=app.node_postprocessor, + node_postprocessors=app.node_postprocessors, ) code_str = ( openbbchat_output.response.split("```python")[1].split("```")[0] diff --git a/gptstonks_api/utils.py b/gptstonks_api/utils.py index f9f8e9d..f6c5c29 100644 --- a/gptstonks_api/utils.py +++ b/gptstonks_api/utils.py @@ -124,11 +124,12 @@ def get_keys_file(): async def get_openbb_chat_output( query_str: str, auto_llama_index: AutoLlamaIndex, - node_postprocessor: Optional[BaseNodePostprocessor] = None, + node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, ) -> str: nodes = await auto_llama_index._retriever.aretrieve(query_str) - if node_postprocessor is not None: - nodes = node_postprocessor.postprocess_nodes(nodes) + if node_postprocessors is not None: + for node_postprocessor in node_postprocessors: + nodes = node_postprocessor.postprocess_nodes(nodes) return await auto_llama_index._query_engine.asynthesize(query_bundle=query_str, nodes=nodes) @@ -153,10 +154,10 @@ async def get_openbb_chat_output_executed( query_str: str, auto_llama_index: AutoLlamaIndex, python_repl_utility: PythonREPL, - node_postprocessor: Optional[BaseNodePostprocessor] = None, + node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, openbb_pat: Optional[str] = None, ) -> str: - output_res = await get_openbb_chat_output(query_str, auto_llama_index, node_postprocessor) + output_res = await get_openbb_chat_output(query_str, auto_llama_index, node_postprocessors) code_str = ( output_res.response.split("```python")[1].split("```")[0] if "```python" in output_res.response