Run make format #1106

Merged (1 commit) on Jul 14, 2025
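
This PR is the output of running the repository's make format target. All of the changes below are mechanical (presumably ruff and/or black; the exact toolchain is not shown in the diff): quotes are normalized to double quotes, trailing commas are added, blank lines are inserted around definitions, and long call expressions are reflowed. No behavior changes.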
2 changes: 1 addition & 1 deletion examples/mcp/prompt_server/main.py
@@ -17,7 +17,7 @@ async def get_instructions_from_prompt(mcp_server: MCPServer, prompt_name: str,
     try:
         prompt_result = await mcp_server.get_prompt(prompt_name, kwargs)
         content = prompt_result.messages[0].content
-        if hasattr(content, 'text'):
+        if hasattr(content, "text"):
             instructions = content.text
         else:
             instructions = str(content)
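
The hunk above only normalizes quote style; the underlying pattern is unchanged. A minimal standalone sketch of that pattern (the function name here is illustrative, not the example's actual API):

def extract_text(content: object) -> str:
    # MCP prompt content may expose a structured .text attribute;
    # fall back to plain stringification otherwise.
    if hasattr(content, "text"):
        return content.text
    return str(content)

print(extract_text(42))  # -> "42" via the fallback branch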
3 changes: 3 additions & 0 deletions src/agents/model_settings.py
@@ -42,15 +42,18 @@ def validate_from_none(value: None) -> _Omit:
             serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
         )
 
+
 @dataclass
 class MCPToolChoice:
     server_label: str
     name: str
 
+
 Omit = Annotated[_Omit, _OmitTypeAnnotation]
 Headers: TypeAlias = Mapping[str, Union[str, Omit]]
 ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]
 
+
 @dataclass
 class ModelSettings:
     """Settings to use when calling an LLM.
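
MCPToolChoice, newly spaced out above, lets tool_choice pin a specific tool on a named MCP server instead of the bare "mcp" string. A minimal usage sketch, assuming ModelSettings exposes a tool_choice field of the ToolChoice alias shown above (server label and tool name are illustrative):

from agents.model_settings import MCPToolChoice, ModelSettings

# Force the model to call the "search" tool on the MCP server labeled "docs".
settings = ModelSettings(tool_choice=MCPToolChoice(server_label="docs", name="search"))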
2 changes: 1 addition & 1 deletion src/agents/models/openai_responses.py
@@ -343,7 +343,7 @@ def convert_tool_choice(
         elif tool_choice == "mcp":
             # Note that this is still here for backwards compatibility,
             # but migrating to MCPToolChoice is recommended.
-            return { "type": "mcp" }  # type: ignore [typeddict-item]
+            return {"type": "mcp"}  # type: ignore [typeddict-item]
         else:
             return {
                 "type": "function",
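
A hedged re-sketch of the branch this hunk touches, written standalone rather than as the library's actual method (which lives on a converter class and returns typed dicts):

def _to_tool_choice_param(tool_choice: str) -> dict:
    if tool_choice == "mcp":
        # Legacy spelling, kept for backwards compatibility;
        # MCPToolChoice is the recommended replacement.
        return {"type": "mcp"}
    # Any other string is treated as a specific function tool's name.
    return {"type": "function", "name": tool_choice}

assert _to_tool_choice_param("mcp") == {"type": "mcp"}
assert _to_tool_choice_param("my_tool") == {"type": "function", "name": "my_tool"}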
82 changes: 45 additions & 37 deletions tests/realtime/test_session.py
@@ -974,27 +974,30 @@ class TestGuardrailFunctionality:
     async def _wait_for_guardrail_tasks(self, session):
         """Wait for all pending guardrail tasks to complete."""
         import asyncio
+
         if session._guardrail_tasks:
             await asyncio.gather(*session._guardrail_tasks, return_exceptions=True)
 
     @pytest.fixture
     def triggered_guardrail(self):
         """Creates a guardrail that always triggers"""
+
         def guardrail_func(context, agent, output):
             return GuardrailFunctionOutput(
-                output_info={"reason": "test trigger"},
-                tripwire_triggered=True
+                output_info={"reason": "test trigger"}, tripwire_triggered=True
             )
+
         return OutputGuardrail(guardrail_function=guardrail_func, name="triggered_guardrail")
 
     @pytest.fixture
     def safe_guardrail(self):
         """Creates a guardrail that never triggers"""
+
         def guardrail_func(context, agent, output):
             return GuardrailFunctionOutput(
-                output_info={"reason": "safe content"},
-                tripwire_triggered=False
+                output_info={"reason": "safe content"}, tripwire_triggered=False
             )
+
         return OutputGuardrail(guardrail_function=guardrail_func, name="safe_guardrail")
 
     @pytest.mark.asyncio
@@ -1004,7 +1007,7 @@ async def test_transcript_delta_triggers_guardrail_at_threshold(
         """Test that guardrails run when transcript delta reaches debounce threshold"""
         run_config: RealtimeRunConfig = {
             "output_guardrails": [triggered_guardrail],
-            "guardrails_settings": {"debounce_text_length": 10}
+            "guardrails_settings": {"debounce_text_length": 10},
         }
 
         session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
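
The debounce rule that debounce_text_length configures, and that the next few tests exercise, fits in one line: a guardrail run is due each time an item's accumulated transcript crosses another multiple of the debounce length. A sketch inferred from the assertions (the helper name is illustrative):

def guardrail_runs_due(total_chars: int, debounce: int, runs_so_far: int) -> int:
    # One run per newly crossed multiple of the debounce length.
    return max(total_chars // debounce - runs_so_far, 0)

assert guardrail_runs_due(10, 10, 0) == 1  # threshold reached exactly
assert guardrail_runs_due(10, 5, 1) == 1   # 2x threshold, second run due
assert guardrail_runs_due(8, 10, 0) == 0   # still below the threshold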
@@ -1041,20 +1044,20 @@ async def test_transcript_delta_multiple_thresholds_same_item(
         """Test guardrails run at 1x, 2x, 3x thresholds for same item_id"""
         run_config: RealtimeRunConfig = {
             "output_guardrails": [triggered_guardrail],
-            "guardrails_settings": {"debounce_text_length": 5}
+            "guardrails_settings": {"debounce_text_length": 5},
         }
 
         session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
 
         # First delta - reaches 1x threshold (5 chars)
-        await session.on_event(RealtimeModelTranscriptDeltaEvent(
-            item_id="item_1", delta="12345", response_id="resp_1"
-        ))
+        await session.on_event(
+            RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="12345", response_id="resp_1")
+        )
 
         # Second delta - reaches 2x threshold (10 chars total)
-        await session.on_event(RealtimeModelTranscriptDeltaEvent(
-            item_id="item_1", delta="67890", response_id="resp_1"
-        ))
+        await session.on_event(
+            RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="67890", response_id="resp_1")
+        )
 
         # Wait for async guardrail tasks to complete
         await self._wait_for_guardrail_tasks(session)
@@ -1070,28 +1073,32 @@ async def test_transcript_delta_different_items_tracked_separately(
         """Test that different item_ids are tracked separately for debouncing"""
         run_config: RealtimeRunConfig = {
             "output_guardrails": [safe_guardrail],
-            "guardrails_settings": {"debounce_text_length": 10}
+            "guardrails_settings": {"debounce_text_length": 10},
         }
 
         session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
 
         # Add text to item_1 (8 chars - below threshold)
-        await session.on_event(RealtimeModelTranscriptDeltaEvent(
-            item_id="item_1", delta="12345678", response_id="resp_1"
-        ))
+        await session.on_event(
+            RealtimeModelTranscriptDeltaEvent(
+                item_id="item_1", delta="12345678", response_id="resp_1"
+            )
+        )
 
         # Add text to item_2 (8 chars - below threshold)
-        await session.on_event(RealtimeModelTranscriptDeltaEvent(
-            item_id="item_2", delta="abcdefgh", response_id="resp_2"
-        ))
+        await session.on_event(
+            RealtimeModelTranscriptDeltaEvent(
+                item_id="item_2", delta="abcdefgh", response_id="resp_2"
+            )
+        )
 
         # Neither should trigger guardrails yet
         assert mock_model.interrupts_called == 0
 
         # Add more text to item_1 (total 12 chars - above threshold)
-        await session.on_event(RealtimeModelTranscriptDeltaEvent(
-            item_id="item_1", delta="90ab", response_id="resp_1"
-        ))
+        await session.on_event(
+            RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="90ab", response_id="resp_1")
+        )
 
         # item_1 should have triggered guardrail run (but not interrupted since safe)
         assert session._item_guardrail_run_counts["item_1"] == 1
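
The per-item separation asserted above can be modeled in a few lines. This toy version mirrors the test's exact inputs (the state dicts are assumed; only _item_guardrail_run_counts appears in the real session):

run_counts: dict[str, int] = {}
transcript_lengths: dict[str, int] = {}

def on_delta(item_id: str, delta: str, debounce: int = 10) -> None:
    # Accumulate per item_id; fire runs only for newly crossed multiples.
    transcript_lengths[item_id] = transcript_lengths.get(item_id, 0) + len(delta)
    due = transcript_lengths[item_id] // debounce
    if due > run_counts.get(item_id, 0):
        run_counts[item_id] = due

on_delta("item_1", "12345678")  # 8 chars: below threshold
on_delta("item_2", "abcdefgh")  # tracked separately, also below
on_delta("item_1", "90ab")      # item_1 reaches 12 chars: first run fires
assert run_counts == {"item_1": 1}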
@@ -1107,15 +1114,17 @@ async def test_turn_ended_clears_guardrail_state(
         """Test that turn_ended event clears guardrail state for next turn"""
         run_config: RealtimeRunConfig = {
             "output_guardrails": [triggered_guardrail],
-            "guardrails_settings": {"debounce_text_length": 5}
+            "guardrails_settings": {"debounce_text_length": 5},
        }
 
         session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
 
         # Trigger guardrail
-        await session.on_event(RealtimeModelTranscriptDeltaEvent(
-            item_id="item_1", delta="trigger", response_id="resp_1"
-        ))
+        await session.on_event(
+            RealtimeModelTranscriptDeltaEvent(
+                item_id="item_1", delta="trigger", response_id="resp_1"
+            )
+        )
 
         # Wait for async guardrail tasks to complete
         await self._wait_for_guardrail_tasks(session)
@@ -1132,31 +1141,30 @@ async def test_turn_ended_clears_guardrail_state(
         assert len(session._item_guardrail_run_counts) == 0
 
     @pytest.mark.asyncio
-    async def test_multiple_guardrails_all_triggered(
-        self, mock_model, mock_agent
-    ):
+    async def test_multiple_guardrails_all_triggered(self, mock_model, mock_agent):
         """Test that all triggered guardrails are included in the event"""
+
         def create_triggered_guardrail(name):
             def guardrail_func(context, agent, output):
-                return GuardrailFunctionOutput(
-                    output_info={"name": name},
-                    tripwire_triggered=True
-                )
+                return GuardrailFunctionOutput(output_info={"name": name}, tripwire_triggered=True)
+
             return OutputGuardrail(guardrail_function=guardrail_func, name=name)
 
         guardrail1 = create_triggered_guardrail("guardrail_1")
         guardrail2 = create_triggered_guardrail("guardrail_2")
 
         run_config: RealtimeRunConfig = {
             "output_guardrails": [guardrail1, guardrail2],
-            "guardrails_settings": {"debounce_text_length": 5}
+            "guardrails_settings": {"debounce_text_length": 5},
         }
 
         session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
 
-        await session.on_event(RealtimeModelTranscriptDeltaEvent(
-            item_id="item_1", delta="trigger", response_id="resp_1"
-        ))
+        await session.on_event(
+            RealtimeModelTranscriptDeltaEvent(
+                item_id="item_1", delta="trigger", response_id="resp_1"
+            )
+        )
 
         # Wait for async guardrail tasks to complete
         await self._wait_for_guardrail_tasks(session)
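
The aggregation this last test pins down, as a self-contained toy: every triggered guardrail lands in a single tripped event rather than one event apiece (the result shape below is assumed for illustration):

from dataclasses import dataclass

@dataclass
class ToyResult:
    guardrail_name: str
    tripwire_triggered: bool

results = [ToyResult("guardrail_1", True), ToyResult("guardrail_2", True)]
tripped = [r.guardrail_name for r in results if r.tripwire_triggered]
assert tripped == ["guardrail_1", "guardrail_2"]  # both named in the single event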
16 changes: 5 additions & 11 deletions tests/realtime/test_tracing.py
@@ -222,36 +222,30 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket):
         # Create a test agent and runner with tracing disabled
         agent = RealtimeAgent(name="test_agent", instructions="test")
 
-        runner = RealtimeRunner(
-            starting_agent=agent,
-            config={"tracing_disabled": True}
-        )
+        runner = RealtimeRunner(starting_agent=agent, config={"tracing_disabled": True})
 
         # Test the _get_model_settings method directly since that's where the logic is
         model_settings = await runner._get_model_settings(
             agent=agent,
             disable_tracing=True,  # This should come from config["tracing_disabled"]
             initial_settings=None,
-            overrides=None
+            overrides=None,
         )
 
         # When tracing is disabled, model settings should have tracing=None
         assert model_settings["tracing"] is None
 
         # Also test that the runner passes disable_tracing=True correctly
-        with patch.object(runner, '_get_model_settings') as mock_get_settings:
+        with patch.object(runner, "_get_model_settings") as mock_get_settings:
             mock_get_settings.return_value = {"tracing": None}
 
-            with patch('agents.realtime.session.RealtimeSession') as mock_session_class:
+            with patch("agents.realtime.session.RealtimeSession") as mock_session_class:
                 mock_session = AsyncMock()
                 mock_session_class.return_value = mock_session
 
                 await runner.run()
 
                 # Verify that _get_model_settings was called with disable_tracing=True
                 mock_get_settings.assert_called_once_with(
-                    agent=agent,
-                    disable_tracing=True,
-                    initial_settings=None,
-                    overrides=None
+                    agent=agent, disable_tracing=True, initial_settings=None, overrides=None
                 )
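
Formatting aside, this test documents the tracing flow end to end: setting config["tracing_disabled"] to True makes the runner call _get_model_settings(disable_tracing=True, ...), which yields model settings with tracing=None, turning tracing off for the session.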