
Merge pull request #28 from GPTStonks/daedalus/improve-qa-with-complete-query

Add the complete query to the QA tool
Dedalo314 authored Dec 25, 2023
2 parents ada8148 + ccf8f7b commit 76899de
Showing 2 changed files with 16 additions and 4 deletions.
3 changes: 3 additions & 0 deletions gptstonks_api/main.py
```diff
@@ -239,6 +239,9 @@ async def run_model_in_background(query: str, use_agent: bool, openbb_pat: str |
     if use_agent:
         # update openbb tool with PAT (None if not provided)
         app.tools[-1].coroutine = partial(app.tools[-1].coroutine, openbb_pat=openbb_pat)
+        # update QA tools to use original query to respond from the context
+        app.tools[0].coroutine = partial(app.tools[0].coroutine, original_query=query)
+        app.tools[1].coroutine = partial(app.tools[1].coroutine, original_query=query)
         agent_executor = initialize_agent(
             tools=app.tools,
             llm=app.llm,
```
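For context, `functools.partial` works on coroutine functions just as on regular callables: the keyword argument bound here (`original_query`) is supplied on every later await of the tool's coroutine. A minimal, self-contained sketch of that pattern, with a hypothetical stand-in tool class and coroutine rather than the project's real LangChain tools:

```python
import asyncio
from functools import partial
from typing import Optional


class FakeTool:
    """Hypothetical stand-in for a tool object exposing an async `coroutine`."""

    def __init__(self, coroutine):
        self.coroutine = coroutine


async def qa_coroutine(tool_input: str, original_query: Optional[str] = None) -> str:
    # Once bound via partial, `original_query` arrives on every invocation.
    return f"tool_input={tool_input!r}, original_query={original_query!r}"


async def main() -> None:
    tool = FakeTool(qa_coroutine)
    user_query = "Summarize the latest AAPL earnings call"
    # Same pattern as the diff above: wrap the coroutine so later calls see the full query.
    tool.coroutine = partial(tool.coroutine, original_query=user_query)
    print(await tool.coroutine("AAPL earnings"))


if __name__ == "__main__":
    asyncio.run(main())
```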
17 changes: 13 additions & 4 deletions gptstonks_api/utils.py
```diff
@@ -176,13 +176,22 @@ def run_qa_over_tool_output(tool_input: str | dict, llm: BaseLLM, tool: BaseTool
     return f"> Context retrieved using {tool.name}.\n\n" f"{answer}"
 
 
-async def arun_qa_over_tool_output(tool_input: str | dict, llm: BaseLLM, tool: BaseTool) -> str:
+async def arun_qa_over_tool_output(
+    tool_input: str | dict, llm: BaseLLM, tool: BaseTool, original_query: Optional[str] = None
+) -> str:
     tool_output: str = await tool.arun(tool_input)
-    model_prompt: str = PromptTemplate(
+    model_prompt = PromptTemplate(
         input_variables=["context_str", "query_str"],
         template=os.getenv("CUSTOM_GPTSTONKS_QA", DEFAULT_TEXT_QA_PROMPT_TMPL),
-    ).format(query_str=tool_input, context_str=tool_output)
-    answer: str = await llm.apredict(model_prompt)
+    )
+    if original_query is not None:
+        answer: str = await llm.apredict(
+            model_prompt.format(query_str=original_query, context_str=tool_output)
+        )
+    else:
+        answer: str = await llm.apredict(
+            model_prompt.format(query_str=tool_input, context_str=tool_output)
+        )
 
     return f"> Context retrieved using {tool.name}.\n\n" f"{answer}"
```
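The key refactor here is that the prompt template is built once and formatted later, so `query_str` can be either the agent's shortened `tool_input` or the user's full `original_query`. A minimal sketch of that decoupling, using a plain template string as a hypothetical stand-in for `CUSTOM_GPTSTONKS_QA` / `DEFAULT_TEXT_QA_PROMPT_TMPL`:

```python
from typing import Optional

# Hypothetical QA template; the real one is loaded from the env var or the default constant.
QA_TEMPLATE = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information, answer the query: {query_str}\n"
)


def build_qa_prompt(
    tool_input: str, tool_output: str, original_query: Optional[str] = None
) -> str:
    # Mirror the diff's if/else: prefer the user's full query when it was provided.
    query_str = original_query if original_query is not None else tool_input
    return QA_TEMPLATE.format(context_str=tool_output, query_str=query_str)


if __name__ == "__main__":
    print(
        build_qa_prompt(
            tool_input="AAPL revenue",
            tool_output="(tool output with the retrieved figures)",
            original_query="What was Apple's revenue last fiscal year?",
        )
    )
```

Deferring `.format()` until after the `original_query` check keeps template construction in one place and makes the choice of query explicit at the point where the LLM is called.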
