Skip to content

Commit

Permalink
Add multiple node postprocessors
Browse files Browse the repository at this point in the history
Multiple node postprocessors can now be used instead of just one.
  • Loading branch information
Dedalo314 committed Dec 26, 2023
1 parent ce63197 commit d7ed470
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
18 changes: 12 additions & 6 deletions gptstonks_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from llama_index.embeddings import OpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingModelType
from llama_index.llms import LangChainLLM
from llama_index.postprocessor import MetadataReplacementPostProcessor
from llama_index.postprocessor import (
MetadataReplacementPostProcessor,
SimilarityPostprocessor,
)
from openbb import obb
from openbb_chat.kernels.auto_llama_index import AutoLlamaIndex

Expand Down Expand Up @@ -158,10 +161,13 @@ def init_data():
app.AI_PREFIX = "GPTSTONKS_RESPONSE"
app.python_repl_utility = PythonREPL()
app.python_repl_utility.globals = globals()
app.node_postprocessor = (
MetadataReplacementPostProcessor(target_metadata_key="extra_context")
app.node_postprocessors = (
[
SimilarityPostprocessor(similarity_cutoff=0.8),
MetadataReplacementPostProcessor(target_metadata_key="extra_context"),
]
if os.getenv("REMOVE_POSTPROCESSOR", None) is None
else None
else [SimilarityPostprocessor(similarity_cutoff=0.8)]
)
search_tool = DuckDuckGoSearchResults(api_wrapper=DuckDuckGoSearchAPIWrapper())
wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
Expand Down Expand Up @@ -195,7 +201,7 @@ def init_data():
get_openbb_chat_output_executed,
auto_llama_index=app.auto_llama_index,
python_repl_utility=app.python_repl_utility,
node_postprocessor=app.node_postprocessor,
node_postprocessors=app.node_postprocessors,
),
description=os.getenv("OPENBBCHAT_TOOL_DESCRIPTION"),
return_direct=True,
Expand Down Expand Up @@ -286,7 +292,7 @@ async def run_model_in_background(query: str, use_agent: bool, openbb_pat: str |
openbbchat_output = await get_openbb_chat_output(
query_str=query,
auto_llama_index=app.auto_llama_index,
node_postprocessor=app.node_postprocessor,
node_postprocessors=app.node_postprocessors,
)
code_str = (
openbbchat_output.response.split("```python")[1].split("```")[0]
Expand Down
11 changes: 6 additions & 5 deletions gptstonks_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ def get_keys_file():
async def get_openbb_chat_output(
    query_str: str,
    auto_llama_index: AutoLlamaIndex,
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
) -> str:
    """Retrieve nodes for ``query_str`` and synthesize a response with the query engine.

    Args:
        query_str: Natural-language query to answer.
        auto_llama_index: Index object exposing the ``_retriever`` used for node
            retrieval and the ``_query_engine`` used for response synthesis.
        node_postprocessors: Optional chain of postprocessors (e.g. similarity
            filtering, metadata replacement) applied in order to the retrieved
            nodes before synthesis. ``None`` (or an empty list) skips
            postprocessing entirely.

    Returns:
        The response synthesized by the query engine from the (postprocessed)
        retrieved nodes.
    """
    nodes = await auto_llama_index._retriever.aretrieve(query_str)
    # Apply each postprocessor sequentially: each one receives the node list
    # produced by the previous one, so ordering in the list matters.
    for node_postprocessor in node_postprocessors or []:
        nodes = node_postprocessor.postprocess_nodes(nodes)
    return await auto_llama_index._query_engine.asynthesize(query_bundle=query_str, nodes=nodes)


Expand All @@ -153,10 +154,10 @@ async def get_openbb_chat_output_executed(
query_str: str,
auto_llama_index: AutoLlamaIndex,
python_repl_utility: PythonREPL,
node_postprocessor: Optional[BaseNodePostprocessor] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
openbb_pat: Optional[str] = None,
) -> str:
output_res = await get_openbb_chat_output(query_str, auto_llama_index, node_postprocessor)
output_res = await get_openbb_chat_output(query_str, auto_llama_index, node_postprocessors)
code_str = (
output_res.response.split("```python")[1].split("```")[0]
if "```python" in output_res.response
Expand Down

0 comments on commit d7ed470

Please sign in to comment.