Add OpenAI agent to AutoMultiStepQueryEngine
A new factory is added to construct multi-step agents based on OpenAI
function calling, which provides better results when using OpenAI models.
Dedalo314 committed Mar 10, 2024
1 parent 393f9e5 commit 8a2509a
Showing 2 changed files with 209 additions and 25 deletions.
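As a quick orientation before the diff, here is a minimal sketch of how the new OpenAI-function-calling factory is meant to be called. It mirrors the test added in this commit; the `web_search` helper and its description are hypothetical placeholders, and the import paths assume llama-index 0.10+ together with this repository's package layout.

# Hypothetical usage sketch of the factory added in this commit.
# `web_search` stands in for real tool callables (the added test uses LangChain tools).
from llama_index.llms.openai import OpenAI

from openbb_chat.kernels.auto_multistep_query_engine import AutoMultiStepQueryEngine


def web_search(query: str) -> str:
    """Placeholder tool: returns canned text instead of real search results."""
    return f"Top results for: {query}"


query_engine = AutoMultiStepQueryEngine.from_simple_openai_agent(
    llm=OpenAI(model="gpt-4-0125-preview"),
    funcs=[web_search],
    names=["web_search"],
    descriptions=["Search the web for up-to-date information."],
    index_summary="Useful to get information on the Internet",
    verbose=True,
)

print(query_engine.query("What did OpenBB release most recently?"))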
175 changes: 151 additions & 24 deletions openbb_chat/kernels/auto_multistep_query_engine.py
@@ -1,36 +1,32 @@
from typing import Awaitable, Callable, List, Optional

from llama_index.agent.openai import OpenAIAgent
from llama_index.core.agent import ReActAgent
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.indices.query.query_transform.base import (
StepDecomposeQueryTransform,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.query_engine import MultiStepQueryEngine
from llama_index.core.tools import FunctionTool
from llama_index.core.tools import BaseTool, FunctionTool
from pydantic import BaseModel


class SimpleLlamaIndexReActAgent(ReActAgent):
"""LlamaIndex agent based over LangChain agent tools.
LlamaIndex agents inherit from Query Engines so they can be converted directly to other tools
to be used by new agents (hierarchical agents).
"""
class BaseLlamaIndexAgent(BaseModel):
"""Base class for agents based on functions and names."""

@classmethod
def from_scratch(
def tools_from_scratch(
cls,
funcs: List[Callable],
names: List[str],
descriptions: List[str],
async_funcs: Optional[List[Callable[..., Awaitable]]] = None,
**kwargs,
) -> "ReActAgent":
"""Convenience constructor method from set of callables.
Extra arguments are sent directly to `ReActAgent.from_tools`.
) -> list[BaseTool]:
"""Convenience intermediate method to create factories for agents.
Returns:
ReActAgent
`list[BaseTool]`: list of LlamaIndex tools.
"""

if len(funcs) == 0:
@@ -53,32 +49,124 @@ def from_scratch(
)
)

return tools


class SimpleLlamaIndexReActAgent(ReActAgent):
"""LlamaIndex agent based over LangChain agent tools.
LlamaIndex agents inherit from Query Engines so they can be converted directly to other tools
to be used by new agents (hierarchical agents).
"""

@classmethod
def from_scratch(
cls,
funcs: List[Callable],
names: List[str],
descriptions: List[str],
async_funcs: Optional[List[Callable[..., Awaitable]]] = None,
**kwargs,
) -> "ReActAgent":
"""Convenience constructor method from set of callables.
Extra arguments are sent directly to `ReActAgent.from_tools`.
Returns:
ReActAgent
"""

tools = BaseLlamaIndexAgent.tools_from_scratch(
funcs=funcs, names=names, descriptions=descriptions, async_funcs=async_funcs
)

return cls.from_tools(tools, **kwargs)


class SimpleLlamaIndexOpenAIAgent(OpenAIAgent):
"""LlamaIndex OpenAI agent based over LangChain agent tools.
LlamaIndex agents inherit from Query Engines so they can be converted directly to other tools
to be used by new agents (hierarchical agents).
"""

@classmethod
def from_scratch(
cls,
funcs: List[Callable],
names: List[str],
descriptions: List[str],
async_funcs: Optional[List[Callable[..., Awaitable]]] = None,
**kwargs,
) -> "OpenAIAgent":
"""Convenience constructor method from set of callables.
Extra arguments are sent directly to `ReActAgent.from_tools`.
Returns:
OpenAIAgent
"""

tools = BaseLlamaIndexAgent.tools_from_scratch(
funcs=funcs, names=names, descriptions=descriptions, async_funcs=async_funcs
)

return cls.from_tools(tools, **kwargs)


class AutoMultiStepQueryEngine(MultiStepQueryEngine):
"""Auto class for creating a query engine."""

@classmethod
def from_query_engine(
cls,
llm: LLM,
query_engine: BaseQueryEngine,
index_summary: str,
verbose: bool = False,
step_decompose_query_transform_kwargs: dict = {},
**kwargs,
) -> "AutoMultiStepQueryEngine":
"""Convenience constructor method from a query_engine.
Extra arguments are sent directly to the `MultiStepQueryEngine` constructor.
Returns:
AutoMultiStepQueryEngine
"""

# transform to decompose the input query into multiple queries
step_decompose_transform = StepDecomposeQueryTransform(
llm=llm, verbose=verbose, **step_decompose_query_transform_kwargs
)

return cls(
query_engine=query_engine,
query_transform=step_decompose_transform,
index_summary=index_summary,
**kwargs,
)

@classmethod
def from_simple_react_agent(
cls,
llm: LLM,
funcs: List[Callable],
names: List[str],
descriptions: List[str],
index_summary: str,
async_funcs: Optional[List[Callable[..., Awaitable]]] = None,
index_summary: Optional[str] = None,
verbose: bool = False,
step_decompose_query_transform_kwargs: dict = {},
simple_llama_index_react_agent_kwargs: dict = {},
llama_index_agent_from_tools_kwargs: dict = {},
**kwargs,
) -> "MultiStepQueryEngine":
"""Convenience constructor method from set of callables.
Extra arguments are sent directly to `MultiStepQueryEngine` constructor.
Returns:
ReActAgent
MultiStepQueryEngine
"""

# llama-index agent, inherits from query engine
@@ -89,17 +177,56 @@ def from_simple_react_agent(
descriptions=descriptions,
llm=llm,
verbose=verbose,
**simple_llama_index_react_agent_kwargs,
**llama_index_agent_from_tools_kwargs,
)

# transform to decompose the input query into multiple queries
step_decompose_transform = StepDecomposeQueryTransform(
llm=llm, verbose=verbose, **step_decompose_query_transform_kwargs
return cls.from_query_engine(
llm=llm,
query_engine=agent,
index_summary=index_summary,
verbose=verbose,
step_decompose_query_transform_kwargs=step_decompose_query_transform_kwargs,
**kwargs,
)

return cls(
@classmethod
def from_simple_openai_agent(
cls,
llm: LLM,
funcs: List[Callable],
names: List[str],
descriptions: List[str],
index_summary: str,
async_funcs: Optional[List[Callable[..., Awaitable]]] = None,
verbose: bool = False,
step_decompose_query_transform_kwargs: dict = {},
llama_index_agent_from_tools_kwargs: dict = {},
**kwargs,
) -> "MultiStepQueryEngine":
"""Convenience constructor method from set of callables.
Extra arguments are sent directly to `MultiStepQueryEngine` constructor.
Returns:
MultiStepQueryEngine
"""

# llama-index agent, inherits from query engine
agent = SimpleLlamaIndexOpenAIAgent.from_scratch(
funcs=funcs,
async_funcs=async_funcs,
names=names,
descriptions=descriptions,
llm=llm,
verbose=verbose,
**llama_index_agent_from_tools_kwargs,
)

return cls.from_query_engine(
llm=llm,
query_engine=agent,
query_transform=step_decompose_transform,
index_summary=index_summary or "Used to answer complex questions using the Internet",
index_summary=index_summary,
verbose=verbose,
step_decompose_query_transform_kwargs=step_decompose_query_transform_kwargs,
**kwargs,
)
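Both agent factories now delegate to the shared `from_query_engine` classmethod, which can also be used directly with any existing `BaseQueryEngine`. Below is a minimal sketch under assumed llama-index 0.10+ import paths; the document directory and index summary are illustrative and not part of this commit.

# Hypothetical direct use of the shared `from_query_engine` factory.
# The `./reports` directory and the index summary are made up for illustration.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.llms.openai import OpenAI

from openbb_chat.kernels.auto_multistep_query_engine import AutoMultiStepQueryEngine

llm = OpenAI(model="gpt-4-0125-preview")
documents = SimpleDirectoryReader("./reports").load_data()
base_engine = VectorStoreIndex.from_documents(documents).as_query_engine(llm=llm)

multi_step_engine = AutoMultiStepQueryEngine.from_query_engine(
    llm=llm,
    query_engine=base_engine,
    index_summary="Quarterly financial reports",
    verbose=True,
)

print(multi_step_engine.query("How did revenue change between the last two quarters?"))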
59 changes: 58 additions & 1 deletion tests/kernels/test_auto_multistep_query_engine.py
@@ -21,7 +21,7 @@
@pytest.mark.asyncio
@patch.object(BaseQueryEngine, "query")
@patch.object(BaseQueryEngine, "aquery")
async def test_factory(mocked_aquery, mocked_query):
async def test_factory_react_agent(mocked_aquery, mocked_query):
os.environ["OPENAI_API_KEY"] = "sk-..."

# LangChain BaseTools to use
@@ -64,6 +64,63 @@ async def request_get_tool_async_func(x):
request_get_tool.description,
],
verbose=True,
index_summary="Useful to get information on the Internet",
)
query_engine.query("Whatever")
await query_engine.aquery("Whatever again")

assert len(query_engine.get_prompts()) == 3
mocked_aquery.assert_called_once()
mocked_query.assert_called_once()


@pytest.mark.asyncio
@patch.object(BaseQueryEngine, "query")
@patch.object(BaseQueryEngine, "aquery")
async def test_factory_openai_agent(mocked_aquery, mocked_query):
os.environ["OPENAI_API_KEY"] = "sk-..."

# LangChain BaseTools to use
search_tool = DuckDuckGoSearchResults(api_wrapper=DuckDuckGoSearchAPIWrapper())
wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
request_get_tool = RequestsGetTool(requests_wrapper=TextRequestsWrapper())

# base parameters for the auto model
def search_tool_func(x):
return search_tool.run(x)

def wikipedia_tool_func(x):
return wikipedia_tool.run(x)

def request_get_tool_func(x):
return request_get_tool.run(x)

async def search_tool_async_func(x):
return await search_tool.arun(x)

async def wikipedia_tool_async_func(x):
return await wikipedia_tool.arun(x)

async def request_get_tool_async_func(x):
return await request_get_tool.arun(x)

# create query engine using factory
query_engine = AutoMultiStepQueryEngine.from_simple_openai_agent(
llm=OpenAI(model="gpt-4-0125-preview"),
funcs=[search_tool_func, wikipedia_tool_func, request_get_tool_func],
async_funcs=[
search_tool_async_func,
wikipedia_tool_async_func,
request_get_tool_async_func,
],
names=[search_tool.name, wikipedia_tool.name, request_get_tool.name],
descriptions=[
search_tool.description,
wikipedia_tool.description,
request_get_tool.description,
],
verbose=True,
index_summary="Useful to get information on the Internet",
)
query_engine.query("Whatever")
await query_engine.aquery("Whatever again")
