|
1 | 1 | import json
|
2 | 2 | from collections.abc import Awaitable, Callable
|
3 |
| -from functools import partial |
| 3 | +from functools import partial, wraps |
4 | 4 | from typing import Any
|
5 | 5 | from uuid import UUID
|
6 | 6 |
|
@@ -118,17 +118,19 @@ async def call_tool(developer_id: UUID, tool_name: str, arguments: dict):
|
118 | 118 |
|
119 | 119 | connection_pool = getattr(app.state, "postgres_pool", None)
|
120 | 120 | tool_handler = partial(tool_handler, connection_pool=connection_pool)
|
121 |
| - arguments["developer_id"] = str(developer_id) |
| 121 | + arguments["developer_id"] = developer_id |
122 | 122 |
|
123 | 123 | # Convert all UUIDs to UUID objects
|
124 | 124 | uuid_fields = ["agent_id", "user_id", "task_id", "session_id", "doc_id"]
|
125 | 125 | for field in uuid_fields:
|
126 | 126 | if field in arguments:
|
127 |
| - arguments[field] = UUID(arguments[field]) |
| 127 | + fld = arguments[field] |
| 128 | + if isinstance(fld, str): |
| 129 | + arguments[field] = UUID(fld) |
128 | 130 |
|
129 | 131 | parts = tool_name.split(".")
|
130 | 132 | if len(parts) < MIN_TOOL_NAME_SEGMENTS:
|
131 |
| - msg = f"wrong syste tool name: {tool_name}" |
| 133 | + msg = f"invalid system tool name: {tool_name}" |
132 | 134 | raise NameError(msg)
|
133 | 135 |
|
134 | 136 | resource, subresource, operation = parts[0], None, parts[-1]
|
@@ -220,39 +222,47 @@ async def call_tool(developer_id: UUID, tool_name: str, arguments: dict):
|
220 | 222 | return await tool_handler(**arguments)
|
221 | 223 |
|
222 | 224 |
|
223 |
| -async def eval_tool_calls( |
224 |
| - func: Callable[..., Awaitable[ModelResponse | CustomStreamWrapper]], |
| 225 | +def tool_calls_evaluator( |
| 226 | + *, |
225 | 227 | tool_types: set[str],
|
226 | 228 | developer_id: UUID,
|
227 |
| - **kwargs, |
228 | 229 | ):
|
229 |
| - response: ModelResponse | CustomStreamWrapper | None = None |
230 |
| - done = False |
231 |
| - while not done: |
232 |
| - response: ModelResponse | CustomStreamWrapper = await func(**kwargs) |
233 |
| - if not response.choices or not response.choices[0].message.tool_calls: |
| 230 | + def decor( |
| 231 | + func: Callable[..., Awaitable[ModelResponse | CustomStreamWrapper]], |
| 232 | + ): |
| 233 | + @wraps(func) |
| 234 | + async def wrapper(**kwargs): |
| 235 | + response: ModelResponse | CustomStreamWrapper | None = None |
| 236 | + done = False |
| 237 | + while not done: |
| 238 | + response: ModelResponse | CustomStreamWrapper = await func(**kwargs) |
| 239 | + if not response.choices or not response.choices[0].message.tool_calls: |
| 240 | + return response |
| 241 | + |
| 242 | + # TODO: add streaming response handling |
| 243 | + for tool in response.choices[0].message.tool_calls: |
| 244 | + if tool.type not in tool_types: |
| 245 | + done = True |
| 246 | + continue |
| 247 | + |
| 248 | + done = False |
| 249 | + # call a tool |
| 250 | + tool_name = tool.function.name |
| 251 | + tool_args = json.loads(tool.function.arguments) |
| 252 | + tool_response = await call_tool(developer_id, tool_name, tool_args) |
| 253 | + |
| 254 | + # append result to messages from previous step |
| 255 | + messages: list = kwargs.get("messages", []) |
| 256 | + messages.append({ |
| 257 | + "tool_call_id": tool.id, |
| 258 | + "role": "tool", |
| 259 | + "name": tool_name, |
| 260 | + "content": tool_response, |
| 261 | + }) |
| 262 | + kwargs["messages"] = messages |
| 263 | + |
234 | 264 | return response
|
235 | 265 |
|
236 |
| - # TODO: add streaming response handling |
237 |
| - for tool in response.choices[0].message.tool_calls: |
238 |
| - if tool.type not in tool_types: |
239 |
| - done = True |
240 |
| - continue |
| 266 | + return wrapper |
241 | 267 |
|
242 |
| - done = False |
243 |
| - # call a tool |
244 |
| - tool_name = tool.function.name |
245 |
| - tool_args = json.loads(tool.function.arguments) |
246 |
| - tool_response = await call_tool(developer_id, tool_name, tool_args) |
247 |
| - |
248 |
| - # append result to messages from previous step |
249 |
| - messages: list = kwargs.get("messages", []) |
250 |
| - messages.append({ |
251 |
| - "tool_call_id": tool.id, |
252 |
| - "role": "tool", |
253 |
| - "name": tool_name, |
254 |
| - "content": tool_response, |
255 |
| - }) |
256 |
| - kwargs["messages"] = messages |
257 |
| - |
258 |
| - return response |
| 268 | + return decor |
0 commit comments