|
| 1 | +import json |
| 2 | +from collections.abc import Awaitable, Callable |
| 3 | +from functools import partial |
| 4 | +from typing import Any |
| 5 | +from uuid import UUID |
| 6 | + |
| 7 | +from beartype import beartype |
| 8 | +from fastapi.background import BackgroundTasks |
| 9 | +from litellm.utils import CustomStreamWrapper, ModelResponse |
| 10 | + |
| 11 | +from ...app import app |
| 12 | +from ...autogen.openapi_model import ( |
| 13 | + ChatInput, |
| 14 | + CreateDocRequest, |
| 15 | + CreateSessionRequest, |
| 16 | + HybridDocSearchRequest, |
| 17 | + TextOnlyDocSearchRequest, |
| 18 | + UpdateSessionRequest, |
| 19 | + UpdateUserRequest, |
| 20 | + VectorDocSearchRequest, |
| 21 | +) |
| 22 | +from ...queries.agents.create_agent import create_agent as create_agent_query |
| 23 | +from ...queries.agents.delete_agent import delete_agent as delete_agent_query |
| 24 | +from ...queries.agents.get_agent import get_agent as get_agent_query |
| 25 | +from ...queries.agents.list_agents import list_agents as list_agents_query |
| 26 | +from ...queries.agents.update_agent import update_agent as update_agent_query |
| 27 | +from ...queries.developers import get_developer |
| 28 | +from ...queries.docs.delete_doc import delete_doc as delete_doc_query |
| 29 | +from ...queries.docs.list_docs import list_docs as list_docs_query |
| 30 | +from ...queries.entries.get_history import get_history as get_history_query |
| 31 | +from ...queries.sessions.create_session import create_session as create_session_query |
| 32 | +from ...queries.sessions.get_session import get_session as get_session_query |
| 33 | +from ...queries.sessions.list_sessions import list_sessions as list_sessions_query |
| 34 | +from ...queries.sessions.update_session import update_session as update_session_query |
| 35 | +from ...queries.tasks.create_task import create_task as create_task_query |
| 36 | +from ...queries.tasks.delete_task import delete_task as delete_task_query |
| 37 | +from ...queries.tasks.get_task import get_task as get_task_query |
| 38 | +from ...queries.tasks.list_tasks import list_tasks as list_tasks_query |
| 39 | +from ...queries.tasks.update_task import update_task as update_task_query |
| 40 | +from ...queries.users.create_user import create_user as create_user_query |
| 41 | +from ...queries.users.delete_user import delete_user as delete_user_query |
| 42 | +from ...queries.users.get_user import get_user as get_user_query |
| 43 | +from ...queries.users.list_users import list_users as list_users_query |
| 44 | +from ...queries.users.update_user import update_user as update_user_query |
| 45 | + |
| 46 | +# FIXME: Do not use routes directly; |
| 47 | +from ...routers.docs.create_doc import create_agent_doc, create_user_doc |
| 48 | +from ...routers.docs.search_docs import search_agent_docs, search_user_docs |
| 49 | +from ...routers.sessions.chat import chat |
| 50 | + |
| 51 | +MIN_TOOL_NAME_SEGMENTS = 2 |
| 52 | + |
| 53 | + |
| 54 | +_system_tool_handlers = { |
| 55 | + "agent.doc.list": list_docs_query, |
| 56 | + "agent.doc.create": create_agent_doc, |
| 57 | + "agent.doc.delete": delete_doc_query, |
| 58 | + "agent.doc.search": search_agent_docs, |
| 59 | + "agent.list": list_agents_query, |
| 60 | + "agent.get": get_agent_query, |
| 61 | + "agent.create": create_agent_query, |
| 62 | + "agent.update": update_agent_query, |
| 63 | + "agent.delete": delete_agent_query, |
| 64 | + "user.doc.list": list_docs_query, |
| 65 | + "user.doc.create": create_user_doc, |
| 66 | + "user.doc.delete": delete_doc_query, |
| 67 | + "user.doc.search": search_user_docs, |
| 68 | + "user.list": list_users_query, |
| 69 | + "user.get": get_user_query, |
| 70 | + "user.create": create_user_query, |
| 71 | + "user.update": update_user_query, |
| 72 | + "user.delete": delete_user_query, |
| 73 | + "session.list": list_sessions_query, |
| 74 | + "session.get": get_session_query, |
| 75 | + "session.create": create_session_query, |
| 76 | + "session.update": update_session_query, |
| 77 | + "session.chat": chat, |
| 78 | + "session.history": get_history_query, |
| 79 | + "task.list": list_tasks_query, |
| 80 | + "task.get": get_task_query, |
| 81 | + "task.create": create_task_query, |
| 82 | + "task.update": update_task_query, |
| 83 | + "task.delete": delete_task_query, |
| 84 | +} |
| 85 | + |
| 86 | + |
| 87 | +def _create_search_request(arguments: dict) -> Any: |
| 88 | + """Create appropriate search request based on available parameters.""" |
| 89 | + if "text" in arguments and "vector" in arguments: |
| 90 | + return HybridDocSearchRequest( |
| 91 | + text=arguments.pop("text"), |
| 92 | + mmr_strength=arguments.pop("mmr_strength", 0), |
| 93 | + vector=arguments.pop("vector"), |
| 94 | + alpha=arguments.pop("alpha", 0.75), |
| 95 | + confidence=arguments.pop("confidence", 0.5), |
| 96 | + limit=arguments.get("limit", 10), |
| 97 | + ) |
| 98 | + if "text" in arguments: |
| 99 | + return TextOnlyDocSearchRequest( |
| 100 | + text=arguments.pop("text"), |
| 101 | + mmr_strength=arguments.pop("mmr_strength", 0), |
| 102 | + limit=arguments.get("limit", 10), |
| 103 | + ) |
| 104 | + if "vector" in arguments: |
| 105 | + return VectorDocSearchRequest( |
| 106 | + vector=arguments.pop("vector"), |
| 107 | + mmr_strength=arguments.pop("mmr_strength", 0), |
| 108 | + confidence=arguments.pop("confidence", 0.7), |
| 109 | + limit=arguments.get("limit", 10), |
| 110 | + ) |
| 111 | + return None |
| 112 | + |
| 113 | + |
| 114 | +@beartype |
| 115 | +async def call_tool(developer_id: UUID, tool_name: str, arguments: dict): |
| 116 | + tool_handler = _system_tool_handlers.get(tool_name) |
| 117 | + if not tool_handler: |
| 118 | + msg = f"System call not implemented for {tool_name}" |
| 119 | + raise NotImplementedError(msg) |
| 120 | + |
| 121 | + connection_pool = getattr(app.state, "postgres_pool", None) |
| 122 | + tool_handler = partial(tool_handler, connection_pool=connection_pool) |
| 123 | + arguments["developer_id"] = str(developer_id) |
| 124 | + |
| 125 | + # Convert all UUIDs to UUID objects |
| 126 | + uuid_fields = ["agent_id", "user_id", "task_id", "session_id", "doc_id"] |
| 127 | + for field in uuid_fields: |
| 128 | + if field in arguments: |
| 129 | + arguments[field] = UUID(arguments[field]) |
| 130 | + |
| 131 | + parts = tool_name.split(".") |
| 132 | + if len(parts) < MIN_TOOL_NAME_SEGMENTS: |
| 133 | + msg = f"wrong syste tool name: {tool_name}" |
| 134 | + raise NameError(msg) |
| 135 | + |
| 136 | + resource, subresource, operation = parts[0], None, parts[-1] |
| 137 | + if len(parts) > MIN_TOOL_NAME_SEGMENTS: |
| 138 | + subresource = parts[1] |
| 139 | + |
| 140 | + if subresource == "doc" and operation not in ["create", "search"]: |
| 141 | + owner_id_field = f"{resource}_id" |
| 142 | + if owner_id_field in arguments: |
| 143 | + doc_args = { |
| 144 | + "owner_type": resource, |
| 145 | + "owner_id": arguments[owner_id_field], |
| 146 | + **arguments, |
| 147 | + } |
| 148 | + doc_args.pop(owner_id_field) |
| 149 | + arguments = doc_args |
| 150 | + |
| 151 | + # Handle special cases for doc operations |
| 152 | + if operation == "create" and subresource == "doc": |
| 153 | + arguments["x_developer_id"] = arguments.pop("developer_id") |
| 154 | + return await tool_handler( |
| 155 | + data=CreateDocRequest(**arguments.pop("data")), |
| 156 | + **arguments, |
| 157 | + ) |
| 158 | + |
| 159 | + # Handle search operations |
| 160 | + if operation == "search" and subresource == "doc": |
| 161 | + arguments["x_developer_id"] = arguments.pop("developer_id") |
| 162 | + search_params = _create_search_request(arguments) |
| 163 | + return await tool_handler(search_params=search_params, **arguments) |
| 164 | + |
| 165 | + # Handle chat operations |
| 166 | + if operation == "chat" and resource == "session": |
| 167 | + developer = await get_developer( |
| 168 | + developer_id=arguments["developer_id"], |
| 169 | + connection_pool=connection_pool, |
| 170 | + ) # type: ignore[not-callable] |
| 171 | + |
| 172 | + session_id = arguments.get("session_id") |
| 173 | + x_custom_api_key = arguments.get("x_custom_api_key", None) |
| 174 | + chat_input = ChatInput(**arguments) |
| 175 | + bg_runner = BackgroundTasks() |
| 176 | + res = await tool_handler( |
| 177 | + developer=developer, |
| 178 | + session_id=session_id, |
| 179 | + background_tasks=bg_runner, |
| 180 | + x_custom_api_key=x_custom_api_key, |
| 181 | + chat_input=chat_input, |
| 182 | + ) |
| 183 | + await bg_runner() |
| 184 | + return res |
| 185 | + |
| 186 | + # Handle create session |
| 187 | + if operation == "create" and resource == "session": |
| 188 | + developer_id = arguments.pop("developer_id") |
| 189 | + session_id = arguments.pop("session_id", None) |
| 190 | + create_session_request = CreateSessionRequest(**arguments) |
| 191 | + |
| 192 | + return await tool_handler( |
| 193 | + developer_id=developer_id, |
| 194 | + session_id=session_id, |
| 195 | + data=create_session_request, |
| 196 | + ) |
| 197 | + |
| 198 | + # Handle update session |
| 199 | + if operation == "update" and resource == "session": |
| 200 | + developer_id = arguments.pop("developer_id") |
| 201 | + session_id = arguments.pop("session_id") |
| 202 | + update_session_request = UpdateSessionRequest(**arguments) |
| 203 | + |
| 204 | + return await tool_handler( |
| 205 | + developer_id=developer_id, |
| 206 | + session_id=session_id, |
| 207 | + data=update_session_request, |
| 208 | + ) |
| 209 | + |
| 210 | + # Handle update user |
| 211 | + if operation == "update" and resource == "user": |
| 212 | + developer_id = arguments.pop("developer_id") |
| 213 | + user_id = arguments.pop("user_id") |
| 214 | + update_user_request = UpdateUserRequest(**arguments) |
| 215 | + |
| 216 | + return await tool_handler( |
| 217 | + developer_id=developer_id, |
| 218 | + user_id=user_id, |
| 219 | + data=update_user_request, |
| 220 | + ) |
| 221 | + |
| 222 | + return await tool_handler(**arguments) |
| 223 | + |
| 224 | + |
| 225 | +async def eval_tool_calls( |
| 226 | + func: Callable[..., Awaitable[ModelResponse | CustomStreamWrapper]], |
| 227 | + tool_types: set[str], |
| 228 | + developer_id: UUID, |
| 229 | + **kwargs, |
| 230 | +): |
| 231 | + response: ModelResponse | CustomStreamWrapper | None = None |
| 232 | + done = False |
| 233 | + while not done: |
| 234 | + response: ModelResponse | CustomStreamWrapper = await func(**kwargs) |
| 235 | + if not response.choices or not response.choices[0].message.tool_calls: |
| 236 | + return response |
| 237 | + |
| 238 | + # TODO: add streaming response handling |
| 239 | + for tool in response.choices[0].message.tool_calls: |
| 240 | + if tool.type not in tool_types: |
| 241 | + done = True |
| 242 | + continue |
| 243 | + |
| 244 | + done = False |
| 245 | + # call a tool |
| 246 | + tool_name = tool.function.name |
| 247 | + tool_args = json.loads(tool.function.arguments) |
| 248 | + tool_response = await call_tool(developer_id, tool_name, tool_args) |
| 249 | + |
| 250 | + # append result to messages from previous step |
| 251 | + messages: list = kwargs.get("messages", []) |
| 252 | + messages.append({ |
| 253 | + "tool_call_id": tool.id, |
| 254 | + "role": "tool", |
| 255 | + "name": tool_name, |
| 256 | + "content": tool_response, |
| 257 | + }) |
| 258 | + kwargs["messages"] = messages |
| 259 | + |
| 260 | + return response |
0 commit comments