Skip to content

Commit 459aa61

Browse files
feat: Add system tool calls to the chat endpoint
1 parent fc3d2eb commit 459aa61

File tree

2 files changed

+264
-1
lines changed

2 files changed

+264
-1
lines changed

agents-api/agents_api/routers/sessions/chat.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ...queries.entries.create_entries import create_entries
2525
from ...queries.sessions.count_sessions import count_sessions as count_sessions_query
2626
from ..utils.model_validation import validate_model
27+
from ..utils.tools import eval_tool_calls
2728
from .metrics import total_tokens_per_user
2829
from .router import router
2930

@@ -202,7 +203,9 @@ async def chat(
202203
"tags": developer.tags,
203204
"custom_api_key": x_custom_api_key,
204205
}
205-
model_response = await litellm.acompletion(**{**settings, **params})
206+
model_response = await eval_tool_calls(
207+
litellm.acompletion, {"system"}, developer.id, **{**settings, **params}
208+
)
206209

207210
# Save the input and the response to the session history
208211
if chat_input.save:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
import json
2+
from collections.abc import Awaitable, Callable
3+
from functools import partial
4+
from typing import Any
5+
from uuid import UUID
6+
7+
from beartype import beartype
8+
from fastapi.background import BackgroundTasks
9+
from litellm.utils import CustomStreamWrapper, ModelResponse
10+
11+
from ...app import app
12+
from ...autogen.openapi_model import (
13+
ChatInput,
14+
CreateDocRequest,
15+
CreateSessionRequest,
16+
HybridDocSearchRequest,
17+
TextOnlyDocSearchRequest,
18+
UpdateSessionRequest,
19+
UpdateUserRequest,
20+
VectorDocSearchRequest,
21+
)
22+
from ...queries.agents.create_agent import create_agent as create_agent_query
23+
from ...queries.agents.delete_agent import delete_agent as delete_agent_query
24+
from ...queries.agents.get_agent import get_agent as get_agent_query
25+
from ...queries.agents.list_agents import list_agents as list_agents_query
26+
from ...queries.agents.update_agent import update_agent as update_agent_query
27+
from ...queries.developers import get_developer
28+
from ...queries.docs.delete_doc import delete_doc as delete_doc_query
29+
from ...queries.docs.list_docs import list_docs as list_docs_query
30+
from ...queries.entries.get_history import get_history as get_history_query
31+
from ...queries.sessions.create_session import create_session as create_session_query
32+
from ...queries.sessions.get_session import get_session as get_session_query
33+
from ...queries.sessions.list_sessions import list_sessions as list_sessions_query
34+
from ...queries.sessions.update_session import update_session as update_session_query
35+
from ...queries.tasks.create_task import create_task as create_task_query
36+
from ...queries.tasks.delete_task import delete_task as delete_task_query
37+
from ...queries.tasks.get_task import get_task as get_task_query
38+
from ...queries.tasks.list_tasks import list_tasks as list_tasks_query
39+
from ...queries.tasks.update_task import update_task as update_task_query
40+
from ...queries.users.create_user import create_user as create_user_query
41+
from ...queries.users.delete_user import delete_user as delete_user_query
42+
from ...queries.users.get_user import get_user as get_user_query
43+
from ...queries.users.list_users import list_users as list_users_query
44+
from ...queries.users.update_user import update_user as update_user_query
45+
46+
# FIXME: Do not use routes directly;
47+
from ...routers.docs.create_doc import create_agent_doc, create_user_doc
48+
from ...routers.docs.search_docs import search_agent_docs, search_user_docs
49+
from ...routers.sessions.chat import chat
50+
51+
MIN_TOOL_NAME_SEGMENTS = 2
52+
53+
54+
_system_tool_handlers = {
55+
"agent.doc.list": list_docs_query,
56+
"agent.doc.create": create_agent_doc,
57+
"agent.doc.delete": delete_doc_query,
58+
"agent.doc.search": search_agent_docs,
59+
"agent.list": list_agents_query,
60+
"agent.get": get_agent_query,
61+
"agent.create": create_agent_query,
62+
"agent.update": update_agent_query,
63+
"agent.delete": delete_agent_query,
64+
"user.doc.list": list_docs_query,
65+
"user.doc.create": create_user_doc,
66+
"user.doc.delete": delete_doc_query,
67+
"user.doc.search": search_user_docs,
68+
"user.list": list_users_query,
69+
"user.get": get_user_query,
70+
"user.create": create_user_query,
71+
"user.update": update_user_query,
72+
"user.delete": delete_user_query,
73+
"session.list": list_sessions_query,
74+
"session.get": get_session_query,
75+
"session.create": create_session_query,
76+
"session.update": update_session_query,
77+
"session.chat": chat,
78+
"session.history": get_history_query,
79+
"task.list": list_tasks_query,
80+
"task.get": get_task_query,
81+
"task.create": create_task_query,
82+
"task.update": update_task_query,
83+
"task.delete": delete_task_query,
84+
}
85+
86+
87+
def _create_search_request(arguments: dict) -> Any:
88+
"""Create appropriate search request based on available parameters."""
89+
if "text" in arguments and "vector" in arguments:
90+
return HybridDocSearchRequest(
91+
text=arguments.pop("text"),
92+
mmr_strength=arguments.pop("mmr_strength", 0),
93+
vector=arguments.pop("vector"),
94+
alpha=arguments.pop("alpha", 0.75),
95+
confidence=arguments.pop("confidence", 0.5),
96+
limit=arguments.get("limit", 10),
97+
)
98+
if "text" in arguments:
99+
return TextOnlyDocSearchRequest(
100+
text=arguments.pop("text"),
101+
mmr_strength=arguments.pop("mmr_strength", 0),
102+
limit=arguments.get("limit", 10),
103+
)
104+
if "vector" in arguments:
105+
return VectorDocSearchRequest(
106+
vector=arguments.pop("vector"),
107+
mmr_strength=arguments.pop("mmr_strength", 0),
108+
confidence=arguments.pop("confidence", 0.7),
109+
limit=arguments.get("limit", 10),
110+
)
111+
return None
112+
113+
114+
@beartype
115+
async def call_tool(developer_id: UUID, tool_name: str, arguments: dict):
116+
tool_handler = _system_tool_handlers.get(tool_name)
117+
if not tool_handler:
118+
msg = f"System call not implemented for {tool_name}"
119+
raise NotImplementedError(msg)
120+
121+
connection_pool = getattr(app.state, "postgres_pool", None)
122+
tool_handler = partial(tool_handler, connection_pool=connection_pool)
123+
arguments["developer_id"] = str(developer_id)
124+
125+
# Convert all UUIDs to UUID objects
126+
uuid_fields = ["agent_id", "user_id", "task_id", "session_id", "doc_id"]
127+
for field in uuid_fields:
128+
if field in arguments:
129+
arguments[field] = UUID(arguments[field])
130+
131+
parts = tool_name.split(".")
132+
if len(parts) < MIN_TOOL_NAME_SEGMENTS:
133+
msg = f"wrong syste tool name: {tool_name}"
134+
raise NameError(msg)
135+
136+
resource, subresource, operation = parts[0], None, parts[-1]
137+
if len(parts) > MIN_TOOL_NAME_SEGMENTS:
138+
subresource = parts[1]
139+
140+
if subresource == "doc" and operation not in ["create", "search"]:
141+
owner_id_field = f"{resource}_id"
142+
if owner_id_field in arguments:
143+
doc_args = {
144+
"owner_type": resource,
145+
"owner_id": arguments[owner_id_field],
146+
**arguments,
147+
}
148+
doc_args.pop(owner_id_field)
149+
arguments = doc_args
150+
151+
# Handle special cases for doc operations
152+
if operation == "create" and subresource == "doc":
153+
arguments["x_developer_id"] = arguments.pop("developer_id")
154+
return await tool_handler(
155+
data=CreateDocRequest(**arguments.pop("data")),
156+
**arguments,
157+
)
158+
159+
# Handle search operations
160+
if operation == "search" and subresource == "doc":
161+
arguments["x_developer_id"] = arguments.pop("developer_id")
162+
search_params = _create_search_request(arguments)
163+
return await tool_handler(search_params=search_params, **arguments)
164+
165+
# Handle chat operations
166+
if operation == "chat" and resource == "session":
167+
developer = await get_developer(
168+
developer_id=arguments["developer_id"],
169+
connection_pool=connection_pool,
170+
) # type: ignore[not-callable]
171+
172+
session_id = arguments.get("session_id")
173+
x_custom_api_key = arguments.get("x_custom_api_key", None)
174+
chat_input = ChatInput(**arguments)
175+
bg_runner = BackgroundTasks()
176+
res = await tool_handler(
177+
developer=developer,
178+
session_id=session_id,
179+
background_tasks=bg_runner,
180+
x_custom_api_key=x_custom_api_key,
181+
chat_input=chat_input,
182+
)
183+
await bg_runner()
184+
return res
185+
186+
# Handle create session
187+
if operation == "create" and resource == "session":
188+
developer_id = arguments.pop("developer_id")
189+
session_id = arguments.pop("session_id", None)
190+
create_session_request = CreateSessionRequest(**arguments)
191+
192+
return await tool_handler(
193+
developer_id=developer_id,
194+
session_id=session_id,
195+
data=create_session_request,
196+
)
197+
198+
# Handle update session
199+
if operation == "update" and resource == "session":
200+
developer_id = arguments.pop("developer_id")
201+
session_id = arguments.pop("session_id")
202+
update_session_request = UpdateSessionRequest(**arguments)
203+
204+
return await tool_handler(
205+
developer_id=developer_id,
206+
session_id=session_id,
207+
data=update_session_request,
208+
)
209+
210+
# Handle update user
211+
if operation == "update" and resource == "user":
212+
developer_id = arguments.pop("developer_id")
213+
user_id = arguments.pop("user_id")
214+
update_user_request = UpdateUserRequest(**arguments)
215+
216+
return await tool_handler(
217+
developer_id=developer_id,
218+
user_id=user_id,
219+
data=update_user_request,
220+
)
221+
222+
return await tool_handler(**arguments)
223+
224+
225+
async def eval_tool_calls(
226+
func: Callable[..., Awaitable[ModelResponse | CustomStreamWrapper]],
227+
tool_types: set[str],
228+
developer_id: UUID,
229+
**kwargs,
230+
):
231+
response: ModelResponse | CustomStreamWrapper | None = None
232+
done = False
233+
while not done:
234+
response: ModelResponse | CustomStreamWrapper = await func(**kwargs)
235+
if not response.choices or not response.choices[0].message.tool_calls:
236+
return response
237+
238+
# TODO: add streaming response handling
239+
for tool in response.choices[0].message.tool_calls:
240+
if tool.type not in tool_types:
241+
done = True
242+
continue
243+
244+
done = False
245+
# call a tool
246+
tool_name = tool.function.name
247+
tool_args = json.loads(tool.function.arguments)
248+
tool_response = await call_tool(developer_id, tool_name, tool_args)
249+
250+
# append result to messages from previous step
251+
messages: list = kwargs.get("messages", [])
252+
messages.append({
253+
"tool_call_id": tool.id,
254+
"role": "tool",
255+
"name": tool_name,
256+
"content": tool_response,
257+
})
258+
kwargs["messages"] = messages
259+
260+
return response

0 commit comments

Comments
 (0)