diff --git a/openbb_chat/kernels/auto_multistep_query_engine.py b/openbb_chat/kernels/auto_multistep_query_engine.py index 53e35f6..53c2120 100644 --- a/openbb_chat/kernels/auto_multistep_query_engine.py +++ b/openbb_chat/kernels/auto_multistep_query_engine.py @@ -1,36 +1,32 @@ from typing import Awaitable, Callable, List, Optional +from llama_index.agent.openai import OpenAIAgent from llama_index.core.agent import ReActAgent +from llama_index.core.base.base_query_engine import BaseQueryEngine from llama_index.core.indices.query.query_transform.base import ( StepDecomposeQueryTransform, ) from llama_index.core.llms.llm import LLM from llama_index.core.query_engine import MultiStepQueryEngine -from llama_index.core.tools import FunctionTool +from llama_index.core.tools import BaseTool, FunctionTool +from pydantic import BaseModel -class SimpleLlamaIndexReActAgent(ReActAgent): - """LlamaIndex agent based over LangChain agent tools. - - LlamaIndex agents inherit from Query Engines so they can be converted directly to other tools - to be used by new agents (hierarchical agents). - """ +class BaseLlamaIndexAgent(BaseModel): + """Base class for agents based on functions and names.""" @classmethod - def from_scratch( + def tools_from_scratch( cls, funcs: List[Callable], names: List[str], descriptions: List[str], async_funcs: Optional[List[Callable[..., Awaitable]]] = None, - **kwargs, - ) -> "ReActAgent": - """Convenience constructor method from set of callables. - - Extra arguments are sent directly to `ReActAgent.from_tools`. + ) -> list[BaseTool]: + """Convenience intermediate method to create factories for agents. Returns: - ReActAgent + `list[BaseTool]`: list of LlamaIndex tools. """ if len(funcs) == 0: @@ -53,12 +49,104 @@ def from_scratch( ) ) + return tools + + +class SimpleLlamaIndexReActAgent(ReActAgent): + """LlamaIndex agent based over LangChain agent tools. + + LlamaIndex agents inherit from Query Engines so they can be converted directly to other tools + to be used by new agents (hierarchical agents). + """ + + @classmethod + def from_scratch( + cls, + funcs: List[Callable], + names: List[str], + descriptions: List[str], + async_funcs: Optional[List[Callable[..., Awaitable]]] = None, + **kwargs, + ) -> "ReActAgent": + """Convenience constructor method from set of callables. + + Extra arguments are sent directly to `ReActAgent.from_tools`. + + Returns: + ReActAgent + """ + + tools = BaseLlamaIndexAgent.tools_from_scratch( + funcs=funcs, names=names, descriptions=descriptions, async_funcs=async_funcs + ) + + return cls.from_tools(tools, **kwargs) + + +class SimpleLlamaIndexOpenAIAgent(OpenAIAgent): + """LlamaIndex OpenAI agent based over LangChain agent tools. + + LlamaIndex agents inherit from Query Engines so they can be converted directly to other tools + to be used by new agents (hierarchical agents). + """ + + @classmethod + def from_scratch( + cls, + funcs: List[Callable], + names: List[str], + descriptions: List[str], + async_funcs: Optional[List[Callable[..., Awaitable]]] = None, + **kwargs, + ) -> "OpenAIAgent": + """Convenience constructor method from set of callables. + + Extra arguments are sent directly to `ReActAgent.from_tools`. + + Returns: + OpenAIAgent + """ + + tools = BaseLlamaIndexAgent.tools_from_scratch( + funcs=funcs, names=names, descriptions=descriptions, async_funcs=async_funcs + ) + return cls.from_tools(tools, **kwargs) class AutoMultiStepQueryEngine(MultiStepQueryEngine): """Auto class for creating a query engine.""" + @classmethod + def from_query_engine( + cls, + llm: LLM, + query_engine: BaseQueryEngine, + index_summary: str, + verbose: bool = False, + step_decompose_query_transform_kwargs: dict = {}, + **kwargs, + ) -> "AutoMultiStepQueryEngine": + """Convenience constructor method from a query_engine. + + Extra arguments are sent directly to `MultiStepQueryEngine` constructor. + + Returns: + AutoMultiStepQueryEngine + """ + + # transform to decompose input query in multiple queries + step_decompose_transform = StepDecomposeQueryTransform( + llm=llm, verbose=verbose, **step_decompose_query_transform_kwargs + ) + + return cls( + query_engine=query_engine, + query_transform=step_decompose_transform, + index_summary=index_summary, + **kwargs, + ) + @classmethod def from_simple_react_agent( cls, @@ -66,11 +154,11 @@ def from_simple_react_agent( funcs: List[Callable], names: List[str], descriptions: List[str], + index_summary: str, async_funcs: Optional[List[Callable[..., Awaitable]]] = None, - index_summary: Optional[str] = None, verbose: bool = False, step_decompose_query_transform_kwargs: dict = {}, - simple_llama_index_react_agent_kwargs: dict = {}, + llama_index_agent_from_tools_kwargs: dict = {}, **kwargs, ) -> "MultiStepQueryEngine": """Convenience constructor method from set of callables. @@ -78,7 +166,7 @@ def from_simple_react_agent( Extra arguments are sent directly to `MultiStepQueryEngine` constructor. Returns: - ReActAgent + MultiStepQueryEngine """ # llama-index agent, inherits from query engine @@ -89,17 +177,56 @@ def from_simple_react_agent( descriptions=descriptions, llm=llm, verbose=verbose, - **simple_llama_index_react_agent_kwargs, + **llama_index_agent_from_tools_kwargs, ) - # transform to decompose input query in multiple queries - step_decompose_transform = StepDecomposeQueryTransform( - llm=llm, verbose=verbose, **step_decompose_query_transform_kwargs + return cls.from_query_engine( + llm=llm, + query_engine=agent, + index_summary=index_summary, + verbose=verbose, + step_decompose_query_transform_kwargs=step_decompose_query_transform_kwargs, + **kwargs, ) - return cls( + @classmethod + def from_simple_openai_agent( + cls, + llm: LLM, + funcs: List[Callable], + names: List[str], + descriptions: List[str], + index_summary: str, + async_funcs: Optional[List[Callable[..., Awaitable]]] = None, + verbose: bool = False, + step_decompose_query_transform_kwargs: dict = {}, + llama_index_agent_from_tools_kwargs: dict = {}, + **kwargs, + ) -> "MultiStepQueryEngine": + """Convenience constructor method from set of callables. + + Extra arguments are sent directly to `MultiStepQueryEngine` constructor. + + Returns: + MultiStepQueryEngine + """ + + # llama-index agent, inherits from query engine + agent = SimpleLlamaIndexOpenAIAgent.from_scratch( + funcs=funcs, + async_funcs=async_funcs, + names=names, + descriptions=descriptions, + llm=llm, + verbose=verbose, + **llama_index_agent_from_tools_kwargs, + ) + + return cls.from_query_engine( + llm=llm, query_engine=agent, - query_transform=step_decompose_transform, - index_summary=index_summary or "Used to answer complex questions using the Internet", + index_summary=index_summary, + verbose=verbose, + step_decompose_query_transform_kwargs=step_decompose_query_transform_kwargs, **kwargs, ) diff --git a/tests/kernels/test_auto_multistep_query_engine.py b/tests/kernels/test_auto_multistep_query_engine.py index 781c65e..bb21993 100644 --- a/tests/kernels/test_auto_multistep_query_engine.py +++ b/tests/kernels/test_auto_multistep_query_engine.py @@ -21,7 +21,7 @@ @pytest.mark.asyncio @patch.object(BaseQueryEngine, "query") @patch.object(BaseQueryEngine, "aquery") -async def test_factory(mocked_aquery, mocked_query): +async def test_factory_react_agent(mocked_aquery, mocked_query): os.environ["OPENAI_API_KEY"] = "sk-..." # LangChain BaseTools to use @@ -64,6 +64,63 @@ async def request_get_tool_async_func(x): request_get_tool.description, ], verbose=True, + index_summary="Useful to get information on the Internet", + ) + query_engine.query("Whatever") + await query_engine.aquery("Whatever again") + + assert len(query_engine.get_prompts()) == 3 + mocked_aquery.assert_called_once() + mocked_query.assert_called_once() + + +@pytest.mark.asyncio +@patch.object(BaseQueryEngine, "query") +@patch.object(BaseQueryEngine, "aquery") +async def test_factory_openai_agent(mocked_aquery, mocked_query): + os.environ["OPENAI_API_KEY"] = "sk-..." + + # LangChain BaseTools to use + search_tool = DuckDuckGoSearchResults(api_wrapper=DuckDuckGoSearchAPIWrapper()) + wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()) + request_get_tool = RequestsGetTool(requests_wrapper=TextRequestsWrapper()) + + # base parameters for the auto model + def search_tool_func(x): + return search_tool.run(x) + + def wikipedia_tool_func(x): + return wikipedia_tool.run(x) + + def request_get_tool_func(x): + return request_get_tool.run(x) + + async def search_tool_async_func(x): + return await search_tool.arun(x) + + async def wikipedia_tool_async_func(x): + return await wikipedia_tool.arun(x) + + async def request_get_tool_async_func(x): + return await request_get_tool.arun(x) + + # create query engine using factory + query_engine = AutoMultiStepQueryEngine.from_simple_openai_agent( + llm=OpenAI(model="gpt-4-0125-preview"), + funcs=[search_tool_func, wikipedia_tool_func, request_get_tool_func], + async_funcs=[ + search_tool_async_func, + wikipedia_tool_async_func, + request_get_tool_async_func, + ], + names=[search_tool.name, wikipedia_tool.name, request_get_tool.name], + descriptions=[ + search_tool.description, + wikipedia_tool.description, + request_get_tool.description, + ], + verbose=True, + index_summary="Useful to get information on the Internet", ) query_engine.query("Whatever") await query_engine.aquery("Whatever again")