Skip to content

Commit 5d856e9

Browse files
authored
Merge branch 'main' into update-docs-job
2 parents 906ca67 + 741da67 commit 5d856e9

File tree

16 files changed

+666
-30
lines changed

16 files changed

+666
-30
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ requires-python = ">=3.9"
77
license = "MIT"
88
authors = [{ name = "OpenAI", email = "support@openai.com" }]
99
dependencies = [
10-
"openai>=1.87.0",
10+
"openai>=1.93.1, <2",
1111
"pydantic>=2.10, <3",
1212
"griffe>=1.5.6, <2",
1313
"typing-extensions>=4.12.2, <5",

src/agents/model_settings.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,14 @@ def validate_from_none(value: None) -> _Omit:
4242
serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
4343
)
4444

45+
@dataclass
46+
class MCPToolChoice:
47+
server_label: str
48+
name: str
4549

4650
Omit = Annotated[_Omit, _OmitTypeAnnotation]
4751
Headers: TypeAlias = Mapping[str, Union[str, Omit]]
48-
ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, None]
49-
52+
ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]
5053

5154
@dataclass
5255
class ModelSettings:

src/agents/models/chatcmpl_converter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,20 @@
4444
from ..exceptions import AgentsException, UserError
4545
from ..handoffs import Handoff
4646
from ..items import TResponseInputItem, TResponseOutputItem
47+
from ..model_settings import MCPToolChoice
4748
from ..tool import FunctionTool, Tool
4849
from .fake_id import FAKE_RESPONSES_ID
4950

5051

5152
class Converter:
5253
@classmethod
5354
def convert_tool_choice(
54-
cls, tool_choice: Literal["auto", "required", "none"] | str | None
55+
cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None
5556
) -> ChatCompletionToolChoiceOptionParam | NotGiven:
5657
if tool_choice is None:
5758
return NOT_GIVEN
59+
elif isinstance(tool_choice, MCPToolChoice):
60+
raise UserError("MCPToolChoice is not supported for Chat Completions models")
5861
elif tool_choice == "auto":
5962
return "auto"
6063
elif tool_choice == "required":

src/agents/models/openai_responses.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ..handoffs import Handoff
2626
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
2727
from ..logger import logger
28+
from ..model_settings import MCPToolChoice
2829
from ..tool import (
2930
CodeInterpreterTool,
3031
ComputerTool,
@@ -303,10 +304,16 @@ class ConvertedTools:
303304
class Converter:
304305
@classmethod
305306
def convert_tool_choice(
306-
cls, tool_choice: Literal["auto", "required", "none"] | str | None
307+
cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None
307308
) -> response_create_params.ToolChoice | NotGiven:
308309
if tool_choice is None:
309310
return NOT_GIVEN
311+
elif isinstance(tool_choice, MCPToolChoice):
312+
return {
313+
"server_label": tool_choice.server_label,
314+
"type": "mcp",
315+
"name": tool_choice.name,
316+
}
310317
elif tool_choice == "required":
311318
return "required"
312319
elif tool_choice == "auto":
@@ -334,9 +341,9 @@ def convert_tool_choice(
334341
"type": "code_interpreter",
335342
}
336343
elif tool_choice == "mcp":
337-
return {
338-
"type": "mcp",
339-
}
344+
# Note that this is still here for backwards compatibility,
345+
# but migrating to MCPToolChoice is recommended.
346+
return { "type": "mcp" } # type: ignore [typeddict-item]
340347
else:
341348
return {
342349
"type": "function",

src/agents/realtime/config.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from typing_extensions import NotRequired, TypeAlias, TypedDict
1010

11+
from ..guardrail import OutputGuardrail
1112
from ..model_settings import ToolChoice
1213
from ..tool import Tool
1314

@@ -82,11 +83,43 @@ class RealtimeSessionModelSettings(TypedDict):
8283
tool_choice: NotRequired[ToolChoice]
8384
tools: NotRequired[list[Tool]]
8485

86+
tracing: NotRequired[RealtimeModelTracingConfig | None]
87+
88+
89+
class RealtimeGuardrailsSettings(TypedDict):
90+
"""Settings for output guardrails in realtime sessions."""
91+
92+
debounce_text_length: NotRequired[int]
93+
"""
94+
The minimum number of characters to accumulate before running guardrails on transcript
95+
deltas. Defaults to 100. Guardrails run every time the accumulated text reaches
96+
1x, 2x, 3x, etc. times this threshold.
97+
"""
98+
99+
100+
class RealtimeModelTracingConfig(TypedDict):
101+
"""Configuration for tracing in realtime model sessions."""
102+
103+
workflow_name: NotRequired[str]
104+
"""The workflow name to use for tracing."""
105+
106+
group_id: NotRequired[str]
107+
"""A group identifier to use for tracing, to link multiple traces together."""
108+
109+
metadata: NotRequired[dict[str, Any]]
110+
"""Additional metadata to include with the trace."""
111+
85112

86113
class RealtimeRunConfig(TypedDict):
87114
model_settings: NotRequired[RealtimeSessionModelSettings]
88115

89-
# TODO (rm) Add tracing support
90-
# tracing: NotRequired[RealtimeTracingConfig | None]
91-
# TODO (rm) Add guardrail support
116+
output_guardrails: NotRequired[list[OutputGuardrail[Any]]]
117+
"""List of output guardrails to run on the agent's responses."""
118+
119+
guardrails_settings: NotRequired[RealtimeGuardrailsSettings]
120+
"""Settings for guardrail execution."""
121+
122+
tracing_disabled: NotRequired[bool]
123+
"""Whether tracing is disabled for this run."""
124+
92125
# TODO (rm) Add history audio storage config

src/agents/realtime/events.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from typing_extensions import TypeAlias
77

8+
from ..guardrail import OutputGuardrailResult
89
from ..run_context import RunContextWrapper
910
from ..tool import Tool
1011
from .agent import RealtimeAgent
@@ -181,7 +182,20 @@ class RealtimeHistoryAdded:
181182
type: Literal["history_added"] = "history_added"
182183

183184

184-
# TODO (rm) Add guardrails
185+
@dataclass
186+
class RealtimeGuardrailTripped:
187+
"""A guardrail has been tripped and the agent has been interrupted."""
188+
189+
guardrail_results: list[OutputGuardrailResult]
190+
"""The results from all triggered guardrails."""
191+
192+
message: str
193+
"""The message that was being generated when the guardrail was triggered."""
194+
195+
info: RealtimeEventInfo
196+
"""Common info for all events, such as the context."""
197+
198+
type: Literal["guardrail_tripped"] = "guardrail_tripped"
185199

186200
RealtimeSessionEvent: TypeAlias = Union[
187201
RealtimeAgentStartEvent,
@@ -196,5 +210,6 @@ class RealtimeHistoryAdded:
196210
RealtimeError,
197211
RealtimeHistoryUpdated,
198212
RealtimeHistoryAdded,
213+
RealtimeGuardrailTripped,
199214
]
200215
"""An event emitted by the realtime session."""

src/agents/realtime/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class RealtimeModelConfig(TypedDict):
3838
"""
3939

4040
initial_model_settings: NotRequired[RealtimeSessionModelSettings]
41+
"""The initial model settings to use when connecting."""
4142

4243

4344
class RealtimeModel(abc.ABC):

src/agents/realtime/openai_realtime.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
import os
88
from datetime import datetime
9-
from typing import Any, Callable
9+
from typing import Any, Callable, Literal
1010

1111
import websockets
1212
from openai.types.beta.realtime.conversation_item import ConversationItem
@@ -23,6 +23,7 @@
2323
from ..logger import logger
2424
from .config import (
2525
RealtimeClientMessage,
26+
RealtimeModelTracingConfig,
2627
RealtimeSessionModelSettings,
2728
RealtimeUserInput,
2829
)
@@ -73,6 +74,7 @@ def __init__(self) -> None:
7374
self._audio_length_ms: float = 0.0
7475
self._ongoing_response: bool = False
7576
self._current_audio_content_index: int | None = None
77+
self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
7678

7779
async def connect(self, options: RealtimeModelConfig) -> None:
7880
"""Establish a connection to the model and keep it alive."""
@@ -84,6 +86,11 @@ async def connect(self, options: RealtimeModelConfig) -> None:
8486
self.model = model_settings.get("model_name", self.model)
8587
api_key = await get_api_key(options.get("api_key"))
8688

89+
if "tracing" in model_settings:
90+
self._tracing_config = model_settings["tracing"]
91+
else:
92+
self._tracing_config = "auto"
93+
8794
if not api_key:
8895
raise UserError("API key is required but was not provided.")
8996

@@ -96,6 +103,15 @@ async def connect(self, options: RealtimeModelConfig) -> None:
96103
self._websocket = await websockets.connect(url, additional_headers=headers)
97104
self._websocket_task = asyncio.create_task(self._listen_for_messages())
98105

106+
async def _send_tracing_config(
107+
self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
108+
) -> None:
109+
"""Update tracing configuration via session.update event."""
110+
if tracing_config is not None:
111+
await self.send_event(
112+
{"type": "session.update", "other_data": {"session": {"tracing": tracing_config}}}
113+
)
114+
99115
def add_listener(self, listener: RealtimeModelListener) -> None:
100116
"""Add a listener to the model."""
101117
self._listeners.append(listener)
@@ -343,8 +359,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
343359
self._ongoing_response = False
344360
await self._emit_event(RealtimeModelTurnEndedEvent())
345361
elif parsed.type == "session.created":
346-
# TODO (rm) tracing stuff here
347-
pass
362+
await self._send_tracing_config(self._tracing_config)
348363
elif parsed.type == "error":
349364
await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
350365
elif parsed.type == "conversation.item.deleted":

src/agents/realtime/runner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ async def run(
6969
"""
7070
model_settings = await self._get_model_settings(
7171
agent=self._starting_agent,
72+
disable_tracing=self._config.get("tracing_disabled", False) if self._config else False,
7273
initial_settings=model_config.get("initial_model_settings") if model_config else None,
7374
overrides=self._config.get("model_settings") if self._config else None,
7475
)
@@ -82,13 +83,15 @@ async def run(
8283
agent=self._starting_agent,
8384
context=context,
8485
model_config=model_config,
86+
run_config=self._config,
8587
)
8688

8789
return session
8890

8991
async def _get_model_settings(
9092
self,
9193
agent: RealtimeAgent,
94+
disable_tracing: bool,
9295
context: TContext | None = None,
9396
initial_settings: RealtimeSessionModelSettings | None = None,
9497
overrides: RealtimeSessionModelSettings | None = None,
@@ -109,4 +112,7 @@ async def _get_model_settings(
109112
if overrides:
110113
model_settings.update(overrides)
111114

115+
if disable_tracing:
116+
model_settings["tracing"] = None
117+
112118
return model_settings

0 commit comments

Comments
 (0)