diff --git a/Changelog.md b/Changelog.md index 93de215..e6278ad 100644 --- a/Changelog.md +++ b/Changelog.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - bugfix/42-redis-is-default-vector-db (2024-05-26) - bugfix/33-groq-rag-not-working (2024-05-19) ### Changed + - feature/25-ai-create-tool (2024-08-19) - feature/14-settings-page (2024-08-18) - feature/16-improve-oauth-callback-page-ux (2024-08-18) - feature/17-google-oauth (2024-08-18) diff --git a/backend/src/chains/agent.py b/backend/src/chains/agent.py index e8e549c..e6cc42f 100644 --- a/backend/src/chains/agent.py +++ b/backend/src/chains/agent.py @@ -26,7 +26,8 @@ def agent_chain(body: Agent, endpoints: list[dict] = None, user_id: str = None): tools = gather_tools( tools=body.tools, retriever=retriever, - endpoints=endpoints + endpoints=endpoints, + user_id=user_id ) agent = llm_service.agent( system=system, diff --git a/backend/src/config/tool.py b/backend/src/config/tool.py index e261f59..a435d64 100644 --- a/backend/src/config/tool.py +++ b/backend/src/config/tool.py @@ -38,6 +38,12 @@ 'link': '/tools/pdf_fill_form_fields', 'toolkit': 'Advanced' }, + 'create_api_tool': { + 'name': 'Create API Tool', + 'description': 'Create a new API tool.', + 'link': '/tools/create_api_tool', + 'toolkit': 'Advanced' + }, } ## Available tools diff --git a/backend/src/repositories/tool.py b/backend/src/repositories/tool.py index 9b47668..a11f6de 100644 --- a/backend/src/repositories/tool.py +++ b/backend/src/repositories/tool.py @@ -58,7 +58,7 @@ async def endpoints(self): async def list(self): endpoints = await self.endpoints() - tools = tool_details(endpoints) + tools = tool_details(endpoints, user_id=self.user_id) return tools async def create(self, tool: APITool): @@ -224,9 +224,9 @@ async def delete(self, tool_name: str): return True return None -def tool_repo(request: Request, db: AsyncSession = Depends(get_db)) -> ToolRepository: +def tool_repo(request: Request = None, db: AsyncSession = Depends(get_db), user_id: str = None) -> ToolRepository: try: - return ToolRepository(request=request, db=db) + return ToolRepository(request=request, db=db, user_id=user_id) except NotFoundException as e: # Handle specific NotFoundException with a custom message or logging logging.warning(f"Failed to initialize ToolRepository: {str(e)}") diff --git a/backend/src/routes/user.py b/backend/src/routes/user.py index 5452f05..0ae0e5e 100644 --- a/backend/src/routes/user.py +++ b/backend/src/routes/user.py @@ -152,6 +152,7 @@ async def auth_callback(provider: str, code: str, db: AsyncSession = Depends(get except ValueError as e: return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: + logging.exception(str(e)) return UJSONResponse(detail=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) ################################################################################################################## diff --git a/backend/src/tools/advanced.py b/backend/src/tools/advanced.py index 0ea79f3..dcb57d4 100644 --- a/backend/src/tools/advanced.py +++ b/backend/src/tools/advanced.py @@ -1,3 +1,87 @@ from langchain_experimental.utilities import PythonREPL +from langchain.tools import StructuredTool +from langchain_core.tools import ToolException -python_repl = PythonREPL() \ No newline at end of file +from src.models.tools.api import APITool +from src.services.db import get_db +from src.tools.api import create_schema + +python_repl = PythonREPL() + +def construct_user_tool( + user_id: str, + tool_name: str, + tool_description: str, + args: dict = { + 'name': { + 'description': 'Name of the tool.', + 'type': 'str', + 'default': 'this_is_a_test_tool', + 'required': True + }, + 'description': { + 'description': 'Description of the tool.', + 'type': 'str', + 'default': '', + 'required': True + }, + 'link': { + 'description': 'Link to documentation for the tool.', + 'type': 'str', + 'default': '', + 'required': False + }, + 'toolkit': { + 'description': 'Group of tools.', + 'type': 'str', + 'default': 'Advanced', + 'required': True + }, + 'url': { + 'description': 'URL endpoint for the tool. Can interpolate args into url. Example: https://jsonplaceholder.typicode.com/posts/{post_id}?userId={user_id}', + 'type': 'str', + 'default': '', + 'required': True + }, + 'method': { + 'description': 'HTTP method for the tool. GET, POST, PUT, DELETE.', + 'type': 'str', + 'default': 'GET', + 'required': True + }, + 'headers': { + 'description': """Headers for the tool. Example: {"Content-Type": {"value": "application/json; charset=UTF-8", "encrypted": False}}""", + 'type': 'dict', + 'default': {}, + 'required': False + }, + 'args': { + 'description': """Arguments for the tool. Can be used to interpolate into the URL. Example: {"project_id": {"description": "ID of the project", "type": "str", "default": '123456', "required": True}, "issue_id": {"description": "ID of the issue", "type": "str", "default": '654321', "required": True}}""", + 'type': 'dict', + 'default': {}, + 'required': False + } + }, +) -> StructuredTool: + from src.repositories.tool import tool_repo + if args: + args = create_schema(tool_name+"_schema", args) + + # Define a wrapper function to be used as the callable + async def endpoint_func(**kwargs): + try: + db_gen = get_db() + db = await db_gen.__anext__() + repo = tool_repo(user_id=user_id, db=db) + created = await repo.create(APITool(**kwargs)) + return {"tool": created['value']} + except Exception as e: + raise ToolException(str(e)) + + return StructuredTool.from_function( + name=tool_name, + coroutine=endpoint_func, # Pass the wrapper function here + description=tool_description, + args_schema=args, + handle_tool_error=True, + ) \ No newline at end of file diff --git a/backend/src/utils/tool.py b/backend/src/utils/tool.py index 1c51a71..e8c1b08 100644 --- a/backend/src/utils/tool.py +++ b/backend/src/utils/tool.py @@ -6,6 +6,7 @@ from src.models import RetrievalTool from src.config.tool import AVAILABLE_TOOLS, ENDPOINTS, TOOL_DESCRIPTIONS +from src.tools.advanced import construct_user_tool from src.tools.api import construct_api_tool from src.utils.format import flatten_array @@ -27,10 +28,11 @@ def gather_tools( available_tools: dict[str, any] = None, retriever: VectorStoreRetriever = None, plugins: list[str] = None, - endpoints: list[dict] = None + endpoints: list[dict] = None, + user_id: str = None ): """Gather tools from the tools list""" - constructed_available_tools, _ = construct_tools_and_descriptions(endpoints=endpoints or ENDPOINTS) + constructed_available_tools, _ = construct_tools_and_descriptions(endpoints=endpoints or ENDPOINTS, user_id=user_id) filtered_tools = filter_tools(tools or [], available_tools or constructed_available_tools) ## Add docs tool @@ -55,8 +57,8 @@ def gather_tools( return flatten_array(filtered_tools) -def tool_details(endpoints): - available_tools, tool_descriptions = construct_tools_and_descriptions(endpoints) +def tool_details(endpoints, user_id=None): + available_tools, tool_descriptions = construct_tools_and_descriptions(endpoints, user_id=user_id) return [ { **({'id': tool_descriptions[key]['id']} if 'id' in tool_descriptions[key] else {}), @@ -70,7 +72,7 @@ def tool_details(endpoints): for key in available_tools ] -def construct_tools_and_descriptions(endpoints): +def construct_tools_and_descriptions(endpoints, user_id=None): # Create copies of the original constants available_tools = copy.deepcopy(AVAILABLE_TOOLS) tool_descriptions = copy.deepcopy(TOOL_DESCRIPTIONS) @@ -88,5 +90,15 @@ def construct_tools_and_descriptions(endpoints): 'link': endpoint['link'], 'toolkit': endpoint['toolkit'] } + + if user_id: + tool_name = 'create_api_tool' + create_api_tool = construct_user_tool( + user_id=user_id, + tool_name=tool_name, + tool_description=tool_descriptions[tool_name]['description'], + ) + available_tools[tool_name] = create_api_tool + tool_descriptions[tool_name] = tool_descriptions[tool_name] return available_tools, tool_descriptions \ No newline at end of file