Skip to content

Commit 8e299d8

Browse files
chore: Refactor the evaluator and add tests
1 parent ea7b342 commit 8e299d8

File tree

3 files changed

+575
-38
lines changed

3 files changed

+575
-38
lines changed

agents-api/agents_api/routers/sessions/chat.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ...queries.entries.create_entries import create_entries
2525
from ...queries.sessions.count_sessions import count_sessions as count_sessions_query
2626
from ..utils.model_validation import validate_model
27-
from ..utils.tools import eval_tool_calls
27+
from ..utils.tools import tool_calls_evaluator
2828
from .metrics import total_tokens_per_user
2929
from .router import router
3030

@@ -203,9 +203,12 @@ async def chat(
203203
"tags": developer.tags,
204204
"custom_api_key": x_custom_api_key,
205205
}
206-
model_response = await eval_tool_calls(
207-
litellm.acompletion, {"system"}, developer.id, **{**settings, **params}
208-
)
206+
evaluator = tool_calls_evaluator(tool_types={"system"}, developer_id=developer.id)
207+
acompletion = evaluator(litellm.acompletion)
208+
model_response = await acompletion(**{
209+
**settings,
210+
**params,
211+
})
209212

210213
# Save the input and the response to the session history
211214
if chat_input.save:

agents-api/agents_api/routers/utils/tools.py

+44-34
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from collections.abc import Awaitable, Callable
3-
from functools import partial
3+
from functools import partial, wraps
44
from typing import Any
55
from uuid import UUID
66

@@ -118,17 +118,19 @@ async def call_tool(developer_id: UUID, tool_name: str, arguments: dict):
118118

119119
connection_pool = getattr(app.state, "postgres_pool", None)
120120
tool_handler = partial(tool_handler, connection_pool=connection_pool)
121-
arguments["developer_id"] = str(developer_id)
121+
arguments["developer_id"] = developer_id
122122

123123
# Convert all UUIDs to UUID objects
124124
uuid_fields = ["agent_id", "user_id", "task_id", "session_id", "doc_id"]
125125
for field in uuid_fields:
126126
if field in arguments:
127-
arguments[field] = UUID(arguments[field])
127+
fld = arguments[field]
128+
if isinstance(fld, str):
129+
arguments[field] = UUID(fld)
128130

129131
parts = tool_name.split(".")
130132
if len(parts) < MIN_TOOL_NAME_SEGMENTS:
131-
msg = f"wrong syste tool name: {tool_name}"
133+
msg = f"invalid system tool name: {tool_name}"
132134
raise NameError(msg)
133135

134136
resource, subresource, operation = parts[0], None, parts[-1]
@@ -220,39 +222,47 @@ async def call_tool(developer_id: UUID, tool_name: str, arguments: dict):
220222
return await tool_handler(**arguments)
221223

222224

223-
async def eval_tool_calls(
224-
func: Callable[..., Awaitable[ModelResponse | CustomStreamWrapper]],
225+
def tool_calls_evaluator(
226+
*,
225227
tool_types: set[str],
226228
developer_id: UUID,
227-
**kwargs,
228229
):
229-
response: ModelResponse | CustomStreamWrapper | None = None
230-
done = False
231-
while not done:
232-
response: ModelResponse | CustomStreamWrapper = await func(**kwargs)
233-
if not response.choices or not response.choices[0].message.tool_calls:
230+
def decor(
231+
func: Callable[..., Awaitable[ModelResponse | CustomStreamWrapper]],
232+
):
233+
@wraps(func)
234+
async def wrapper(**kwargs):
235+
response: ModelResponse | CustomStreamWrapper | None = None
236+
done = False
237+
while not done:
238+
response: ModelResponse | CustomStreamWrapper = await func(**kwargs)
239+
if not response.choices or not response.choices[0].message.tool_calls:
240+
return response
241+
242+
# TODO: add streaming response handling
243+
for tool in response.choices[0].message.tool_calls:
244+
if tool.type not in tool_types:
245+
done = True
246+
continue
247+
248+
done = False
249+
# call a tool
250+
tool_name = tool.function.name
251+
tool_args = json.loads(tool.function.arguments)
252+
tool_response = await call_tool(developer_id, tool_name, tool_args)
253+
254+
# append result to messages from previous step
255+
messages: list = kwargs.get("messages", [])
256+
messages.append({
257+
"tool_call_id": tool.id,
258+
"role": "tool",
259+
"name": tool_name,
260+
"content": tool_response,
261+
})
262+
kwargs["messages"] = messages
263+
234264
return response
235265

236-
# TODO: add streaming response handling
237-
for tool in response.choices[0].message.tool_calls:
238-
if tool.type not in tool_types:
239-
done = True
240-
continue
266+
return wrapper
241267

242-
done = False
243-
# call a tool
244-
tool_name = tool.function.name
245-
tool_args = json.loads(tool.function.arguments)
246-
tool_response = await call_tool(developer_id, tool_name, tool_args)
247-
248-
# append result to messages from previous step
249-
messages: list = kwargs.get("messages", [])
250-
messages.append({
251-
"tool_call_id": tool.id,
252-
"role": "tool",
253-
"name": tool_name,
254-
"content": tool_response,
255-
})
256-
kwargs["messages"] = messages
257-
258-
return response
268+
return decor

0 commit comments

Comments
 (0)