Skip to content

Commit

Permalink
Merge pull request #31 from GPTStonks/daedalus/multiple-node-postproc…
Browse files Browse the repository at this point in the history
…essors

Add multiple node postprocessors
  • Loading branch information
Dedalo314 authored Dec 26, 2023
2 parents eca7511 + d7ed470 commit 011bf5e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
18 changes: 12 additions & 6 deletions gptstonks_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from llama_index.embeddings import OpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingModelType
from llama_index.llms import LangChainLLM
from llama_index.postprocessor import MetadataReplacementPostProcessor
from llama_index.postprocessor import (
MetadataReplacementPostProcessor,
SimilarityPostprocessor,
)
from openbb import obb
from openbb_chat.kernels.auto_llama_index import AutoLlamaIndex

Expand Down Expand Up @@ -158,10 +161,13 @@ def init_data():
app.AI_PREFIX = "GPTSTONKS_RESPONSE"
app.python_repl_utility = PythonREPL()
app.python_repl_utility.globals = globals()
app.node_postprocessor = (
MetadataReplacementPostProcessor(target_metadata_key="extra_context")
app.node_postprocessors = (
[
SimilarityPostprocessor(similarity_cutoff=0.8),
MetadataReplacementPostProcessor(target_metadata_key="extra_context"),
]
if os.getenv("REMOVE_POSTPROCESSOR", None) is None
else None
else [SimilarityPostprocessor(similarity_cutoff=0.8)]
)
search_tool = DuckDuckGoSearchResults(api_wrapper=DuckDuckGoSearchAPIWrapper())
wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
Expand Down Expand Up @@ -195,7 +201,7 @@ def init_data():
get_openbb_chat_output_executed,
auto_llama_index=app.auto_llama_index,
python_repl_utility=app.python_repl_utility,
node_postprocessor=app.node_postprocessor,
node_postprocessors=app.node_postprocessors,
),
description=os.getenv("OPENBBCHAT_TOOL_DESCRIPTION"),
return_direct=True,
Expand Down Expand Up @@ -286,7 +292,7 @@ async def run_model_in_background(query: str, use_agent: bool, openbb_pat: str |
openbbchat_output = await get_openbb_chat_output(
query_str=query,
auto_llama_index=app.auto_llama_index,
node_postprocessor=app.node_postprocessor,
node_postprocessors=app.node_postprocessors,
)
code_str = (
openbbchat_output.response.split("```python")[1].split("```")[0]
Expand Down
11 changes: 6 additions & 5 deletions gptstonks_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ def get_keys_file():
async def get_openbb_chat_output(
query_str: str,
auto_llama_index: AutoLlamaIndex,
node_postprocessor: Optional[BaseNodePostprocessor] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
) -> str:
nodes = await auto_llama_index._retriever.aretrieve(query_str)
if node_postprocessor is not None:
nodes = node_postprocessor.postprocess_nodes(nodes)
if node_postprocessors is not None:
for node_postprocessor in node_postprocessors:
nodes = node_postprocessor.postprocess_nodes(nodes)
return await auto_llama_index._query_engine.asynthesize(query_bundle=query_str, nodes=nodes)


Expand All @@ -153,10 +154,10 @@ async def get_openbb_chat_output_executed(
query_str: str,
auto_llama_index: AutoLlamaIndex,
python_repl_utility: PythonREPL,
node_postprocessor: Optional[BaseNodePostprocessor] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
openbb_pat: Optional[str] = None,
) -> str:
output_res = await get_openbb_chat_output(query_str, auto_llama_index, node_postprocessor)
output_res = await get_openbb_chat_output(query_str, auto_llama_index, node_postprocessors)
code_str = (
output_res.response.split("```python")[1].split("```")[0]
if "```python" in output_res.response
Expand Down

0 comments on commit 011bf5e

Please sign in to comment.