Skip to content

Commit 20139cc

Browse files
committed
Realtime: handoffs
1 parent 6293d66 commit 20139cc

File tree

12 files changed

+132
-48
lines changed

12 files changed

+132
-48
lines changed

src/agents/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class Agent(AgentBase, Generic[TContext]):
158158
usable with OpenAI models, using the Responses API.
159159
"""
160160

161-
handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
161+
handoffs: list[Agent[Any] | Handoff[TContext, Any]] = field(default_factory=list)
162162
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
163163
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
164164
modularity.

src/agents/guardrail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def decorator(
244244
return InputGuardrail(
245245
guardrail_function=f,
246246
# If not set, guardrail name uses the function’s name by default.
247-
name=name if name else f.__name__
247+
name=name if name else f.__name__,
248248
)
249249

250250
if func is not None:

src/agents/handoffs.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
from .util._types import MaybeAwaitable
1919

2020
if TYPE_CHECKING:
21-
from .agent import Agent
21+
from .agent import Agent, AgentBase
2222

2323

2424
# The handoff input type is the type of data passed when the agent is called via a handoff.
2525
THandoffInput = TypeVar("THandoffInput", default=Any)
2626

27+
# The agent type that the handoff returns
28+
TAgent = TypeVar("TAgent", bound="AgentBase[Any]", default="Agent[Any]")
29+
2730
OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any]
2831
OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any]
2932

@@ -52,7 +55,7 @@ class HandoffInputData:
5255

5356

5457
@dataclass
55-
class Handoff(Generic[TContext]):
58+
class Handoff(Generic[TContext, TAgent]):
5659
"""A handoff is when an agent delegates a task to another agent.
5760
For example, in a customer support scenario you might have a "triage agent" that determines
5861
which agent should handle the user's request, and sub-agents that specialize in different
@@ -69,7 +72,7 @@ class Handoff(Generic[TContext]):
6972
"""The JSON schema for the handoff input. Can be empty if the handoff does not take an input.
7073
"""
7174

72-
on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[Agent[TContext]]]
75+
on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[TAgent]]
7376
"""The function that invokes the handoff. The parameters passed are:
7477
1. The handoff run context
7578
2. The arguments from the LLM, as a JSON string. Empty string if input_json_schema is empty.
@@ -100,20 +103,22 @@ class Handoff(Generic[TContext]):
100103
True, as it increases the likelihood of correct JSON input.
101104
"""
102105

103-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
106+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
107+
True
108+
)
104109
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
105110
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
106111
a handoff based on your context/state."""
107112

108-
def get_transfer_message(self, agent: Agent[Any]) -> str:
113+
def get_transfer_message(self, agent: AgentBase[Any]) -> str:
109114
return json.dumps({"assistant": agent.name})
110115

111116
@classmethod
112-
def default_tool_name(cls, agent: Agent[Any]) -> str:
117+
def default_tool_name(cls, agent: AgentBase[Any]) -> str:
113118
return _transforms.transform_string_function_style(f"transfer_to_{agent.name}")
114119

115120
@classmethod
116-
def default_tool_description(cls, agent: Agent[Any]) -> str:
121+
def default_tool_description(cls, agent: AgentBase[Any]) -> str:
117122
return (
118123
f"Handoff to the {agent.name} agent to handle the request. "
119124
f"{agent.handoff_description or ''}"
@@ -128,7 +133,7 @@ def handoff(
128133
tool_description_override: str | None = None,
129134
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
130135
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
131-
) -> Handoff[TContext]: ...
136+
) -> Handoff[TContext, Agent[TContext]]: ...
132137

133138

134139
@overload
@@ -141,7 +146,7 @@ def handoff(
141146
tool_name_override: str | None = None,
142147
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
143148
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
144-
) -> Handoff[TContext]: ...
149+
) -> Handoff[TContext, Agent[TContext]]: ...
145150

146151

147152
@overload
@@ -153,7 +158,7 @@ def handoff(
153158
tool_name_override: str | None = None,
154159
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
155160
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
156-
) -> Handoff[TContext]: ...
161+
) -> Handoff[TContext, Agent[TContext]]: ...
157162

158163

159164
def handoff(
@@ -163,8 +168,9 @@ def handoff(
163168
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
164169
input_type: type[THandoffInput] | None = None,
165170
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
166-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
167-
) -> Handoff[TContext]:
171+
is_enabled: bool
172+
| Callable[[RunContextWrapper[Any], Agent[TContext]], MaybeAwaitable[bool]] = True,
173+
) -> Handoff[TContext, Agent[TContext]]:
168174
"""Create a handoff from an agent.
169175
170176
Args:
@@ -202,7 +208,7 @@ def handoff(
202208

203209
async def _invoke_handoff(
204210
ctx: RunContextWrapper[Any], input_json: str | None = None
205-
) -> Agent[Any]:
211+
) -> Agent[TContext]:
206212
if input_type is not None and type_adapter is not None:
207213
if input_json is None:
208214
_error_tracing.attach_error_to_current_span(
@@ -239,12 +245,24 @@ async def _invoke_handoff(
239245
# If there is a need, we can make this configurable in the future
240246
input_json_schema = ensure_strict_json_schema(input_json_schema)
241247

248+
async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool:
249+
from .agent import Agent
250+
251+
assert callable(is_enabled), "is_enabled must be non-null here"
252+
assert isinstance(agent_base, Agent), "Can't handoff to a non-Agent"
253+
result = is_enabled(ctx, agent_base)
254+
255+
if inspect.isawaitable(result):
256+
return await result
257+
258+
return result
259+
242260
return Handoff(
243261
tool_name=tool_name,
244262
tool_description=tool_description,
245263
input_json_schema=input_json_schema,
246264
on_invoke_handoff=_invoke_handoff,
247265
input_filter=input_filter,
248266
agent_name=agent.name,
249-
is_enabled=is_enabled,
267+
is_enabled=_is_enabled if callable(is_enabled) else is_enabled,
250268
)

src/agents/models/chatcmpl_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def tool_to_openai(cls, tool: Tool) -> ChatCompletionToolParam:
484484
)
485485

486486
@classmethod
487-
def convert_handoff_tool(cls, handoff: Handoff[Any]) -> ChatCompletionToolParam:
487+
def convert_handoff_tool(cls, handoff: Handoff[Any, Any]) -> ChatCompletionToolParam:
488488
return {
489489
"type": "function",
490490
"function": {

src/agents/models/openai_responses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def get_response_format(
370370
def convert_tools(
371371
cls,
372372
tools: list[Tool],
373-
handoffs: list[Handoff[Any]],
373+
handoffs: list[Handoff[Any, Any]],
374374
) -> ConvertedTools:
375375
converted_tools: list[ToolParam] = []
376376
includes: list[ResponseIncludable] = []

src/agents/realtime/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RealtimeToolEnd,
3131
RealtimeToolStart,
3232
)
33+
from .handoffs import realtime_handoff
3334
from .items import (
3435
AssistantMessageItem,
3536
AssistantText,
@@ -92,6 +93,8 @@
9293
"RealtimeAgentHooks",
9394
"RealtimeRunHooks",
9495
"RealtimeRunner",
96+
# Handoffs
97+
"realtime_handoff",
9598
# Config
9699
"RealtimeAudioFormat",
97100
"RealtimeClientMessage",

src/agents/realtime/agent.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import dataclasses
44
import inspect
55
from collections.abc import Awaitable
6-
from dataclasses import dataclass
6+
from dataclasses import dataclass, field
77
from typing import Any, Callable, Generic, cast
88

99
from ..agent import AgentBase
10+
from ..handoffs import Handoff
1011
from ..lifecycle import AgentHooksBase, RunHooksBase
1112
from ..logger import logger
1213
from ..run_context import RunContextWrapper, TContext
@@ -53,6 +54,14 @@ class RealtimeAgent(AgentBase, Generic[TContext]):
5354
return a string.
5455
"""
5556

57+
handoffs: list[RealtimeAgent[Any] | Handoff[TContext, RealtimeAgent[Any]]] = field(
58+
default_factory=list
59+
)
60+
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
61+
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
62+
modularity.
63+
"""
64+
5665
hooks: RealtimeAgentHooks | None = None
5766
"""A class that receives callbacks on various lifecycle events for this agent.
5867
"""

src/agents/realtime/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing_extensions import NotRequired, TypeAlias, TypedDict
1010

1111
from ..guardrail import OutputGuardrail
12+
from ..handoffs import Handoff
1213
from ..model_settings import ToolChoice
1314
from ..tool import Tool
1415

@@ -71,6 +72,7 @@ class RealtimeSessionModelSettings(TypedDict):
7172

7273
tool_choice: NotRequired[ToolChoice]
7374
tools: NotRequired[list[Tool]]
75+
handoffs: NotRequired[list[Handoff]]
7476

7577
tracing: NotRequired[RealtimeModelTracingConfig | None]
7678

src/agents/realtime/openai_realtime.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from typing_extensions import assert_never
5757
from websockets.asyncio.client import ClientConnection
5858

59+
from agents.handoffs import Handoff
5960
from agents.tool import FunctionTool, Tool
6061
from agents.util._types import MaybeAwaitable
6162

@@ -519,10 +520,14 @@ def _get_session_config(
519520
"tool_choice",
520521
DEFAULT_MODEL_SETTINGS.get("tool_choice"), # type: ignore
521522
),
522-
tools=self._tools_to_session_tools(model_settings.get("tools", [])),
523+
tools=self._tools_to_session_tools(
524+
tools=model_settings.get("tools", []), handoffs=model_settings.get("handoffs", [])
525+
),
523526
)
524527

525-
def _tools_to_session_tools(self, tools: list[Tool]) -> list[OpenAISessionTool]:
528+
def _tools_to_session_tools(
529+
self, tools: list[Tool], handoffs: list[Handoff]
530+
) -> list[OpenAISessionTool]:
526531
converted_tools: list[OpenAISessionTool] = []
527532
for tool in tools:
528533
if not isinstance(tool, FunctionTool):
@@ -535,6 +540,17 @@ def _tools_to_session_tools(self, tools: list[Tool]) -> list[OpenAISessionTool]:
535540
type="function",
536541
)
537542
)
543+
544+
for handoff in handoffs:
545+
converted_tools.append(
546+
OpenAISessionTool(
547+
name=handoff.tool_name,
548+
description=handoff.tool_description,
549+
parameters=handoff.input_json_schema,
550+
type="function",
551+
)
552+
)
553+
538554
return converted_tools
539555

540556

src/agents/realtime/session.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import inspect
45
from collections.abc import AsyncIterator
56
from typing import Any, cast
67

@@ -31,6 +32,7 @@
3132
RealtimeToolEnd,
3233
RealtimeToolStart,
3334
)
35+
from .handoffs import realtime_handoff
3436
from .items import InputAudio, InputText, RealtimeItem
3537
from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener
3638
from .model_events import (
@@ -255,9 +257,12 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None:
255257

256258
async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
257259
"""Handle a tool call event."""
258-
all_tools = await self._current_agent.get_all_tools(self._context_wrapper)
259-
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
260-
handoff_map = {tool.name: tool for tool in all_tools if isinstance(tool, Handoff)}
260+
tools, handoffs = await asyncio.gather(
261+
self._current_agent.get_all_tools(self._context_wrapper),
262+
self._get_handoffs(self._current_agent, self._context_wrapper),
263+
)
264+
function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
265+
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
261266

262267
if event.name in function_map:
263268
await self._put_event(
@@ -303,7 +308,9 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
303308
# Execute the handoff to get the new agent
304309
result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
305310
if not isinstance(result, RealtimeAgent):
306-
raise UserError(f"Handoff {handoff.name} returned invalid result: {type(result)}")
311+
raise UserError(
312+
f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
313+
)
307314

308315
# Store previous agent for event
309316
previous_agent = self._current_agent
@@ -492,11 +499,37 @@ async def _get__updated_model_settings(
492499
self, new_agent: RealtimeAgent
493500
) -> RealtimeSessionModelSettings:
494501
updated_settings: RealtimeSessionModelSettings = {}
495-
instructions, tools = await asyncio.gather(
502+
instructions, tools, handoffs = await asyncio.gather(
496503
new_agent.get_system_prompt(self._context_wrapper),
497504
new_agent.get_all_tools(self._context_wrapper),
505+
self._get_handoffs(new_agent, self._context_wrapper),
498506
)
499507
updated_settings["instructions"] = instructions or ""
500508
updated_settings["tools"] = tools or []
509+
updated_settings["handoffs"] = handoffs or []
501510

502511
return updated_settings
512+
513+
@classmethod
514+
async def _get_handoffs(
515+
cls, agent: RealtimeAgent[Any], context_wrapper: RunContextWrapper[Any]
516+
) -> list[Handoff[Any, RealtimeAgent[Any]]]:
517+
handoffs: list[Handoff[Any, RealtimeAgent[Any]]] = []
518+
for handoff_item in agent.handoffs:
519+
if isinstance(handoff_item, Handoff):
520+
handoffs.append(handoff_item)
521+
elif isinstance(handoff_item, RealtimeAgent):
522+
handoffs.append(realtime_handoff(handoff_item))
523+
524+
async def _check_handoff_enabled(handoff_obj: Handoff[Any, RealtimeAgent[Any]]) -> bool:
525+
attr = handoff_obj.is_enabled
526+
if isinstance(attr, bool):
527+
return attr
528+
res = attr(context_wrapper, agent)
529+
if inspect.isawaitable(res):
530+
return await res
531+
return res
532+
533+
results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
534+
enabled = [h for h, ok in zip(handoffs, results) if ok]
535+
return enabled

0 commit comments

Comments
 (0)