Skip to content

Commit 42166de

Browse files
committed
fix issues with model and tools params
1 parent f43a0f9 commit 42166de

File tree

4 files changed

+107
-9
lines changed

4 files changed

+107
-9
lines changed

src/agents/models/openai_provider.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,17 @@ def _get_client(self) -> AsyncOpenAI:
8181
return self._client
8282

8383
def get_model(self, model_name: str | None) -> Model:
84-
if model_name is None:
85-
model_name = get_default_model()
84+
model_is_explicit = model_name is not None
85+
resolved_model_name = model_name if model_name is not None else get_default_model()
8686

8787
client = self._get_client()
8888

8989
return (
90-
OpenAIResponsesModel(model=model_name, openai_client=client)
90+
OpenAIResponsesModel(
91+
model=resolved_model_name,
92+
openai_client=client,
93+
model_is_explicit=model_is_explicit,
94+
)
9195
if self._use_responses
92-
else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
96+
else OpenAIChatCompletionsModel(model=resolved_model_name, openai_client=client)
9397
)

src/agents/models/openai_responses.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload
88

9-
from openai import APIStatusError, AsyncOpenAI, AsyncStream, Omit, omit
9+
from openai import APIStatusError, AsyncOpenAI, AsyncStream, NotGiven, Omit, NOT_GIVEN, omit
1010
from openai.types import ChatModel
1111
from openai.types.responses import (
1212
Response,
@@ -67,8 +67,11 @@ def __init__(
6767
self,
6868
model: str | ChatModel,
6969
openai_client: AsyncOpenAI,
70+
*,
71+
model_is_explicit: bool = True,
7072
) -> None:
7173
self.model = model
74+
self._model_is_explicit = model_is_explicit
7275
self._client = openai_client
7376

7477
def _non_null_or_omit(self, value: Any) -> Any:
@@ -262,6 +265,14 @@ async def _fetch_response(
262265
converted_tools = Converter.convert_tools(tools, handoffs)
263266
converted_tools_payload = _to_dump_compatible(converted_tools.tools)
264267
response_format = Converter.get_response_format(output_schema)
268+
should_omit_model = prompt is not None and not self._model_is_explicit
269+
model_param: str | ChatModel | NotGiven = (
270+
self.model if not should_omit_model else NOT_GIVEN
271+
)
272+
should_omit_tools = prompt is not None and len(converted_tools_payload) == 0
273+
tools_param: list[ToolParam] | NotGiven = (
274+
converted_tools_payload if not should_omit_tools else NOT_GIVEN
275+
)
265276

266277
include_set: set[str] = set(converted_tools.includes)
267278
if model_settings.response_include is not None:
@@ -309,10 +320,10 @@ async def _fetch_response(
309320
previous_response_id=self._non_null_or_omit(previous_response_id),
310321
conversation=self._non_null_or_omit(conversation_id),
311322
instructions=self._non_null_or_omit(system_instructions),
312-
model=self.model,
323+
model=model_param,
313324
input=list_input,
314325
include=include,
315-
tools=converted_tools_payload,
326+
tools=tools_param,
316327
prompt=self._non_null_or_omit(prompt),
317328
temperature=self._non_null_or_omit(model_settings.temperature),
318329
top_p=self._non_null_or_omit(model_settings.top_p),

tests/test_agent_prompt.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import pytest
2+
from openai import NOT_GIVEN
23

3-
from agents import Agent, Prompt, RunContextWrapper, Runner
4+
from agents import Agent, Prompt, RunConfig, RunContextWrapper, Runner
5+
from agents.models.interface import Model, ModelProvider
6+
from agents.models.openai_responses import OpenAIResponsesModel
47

58
from .fake_model import FakeModel
9+
from .fake_model import get_response_obj
610
from .test_responses import get_text_message
711

812

@@ -97,3 +101,43 @@ async def test_prompt_is_passed_to_model():
97101
"variables": None,
98102
}
99103
assert model.last_prompt == expected_prompt
104+
105+
106+
class _SingleModelProvider(ModelProvider):
107+
def __init__(self, model: Model):
108+
self._model = model
109+
110+
def get_model(self, model_name: str | None) -> Model:
111+
return self._model
112+
113+
114+
@pytest.mark.allow_call_model_methods
115+
@pytest.mark.asyncio
116+
async def test_agent_prompt_with_default_model_omits_model_and_tools_parameters():
117+
called_kwargs: dict[str, object] = {}
118+
119+
class DummyResponses:
120+
async def create(self, **kwargs):
121+
nonlocal called_kwargs
122+
called_kwargs = kwargs
123+
return get_response_obj([get_text_message("done")])
124+
125+
class DummyResponsesClient:
126+
def __init__(self):
127+
self.responses = DummyResponses()
128+
129+
model = OpenAIResponsesModel(
130+
model="gpt-4.1",
131+
openai_client=DummyResponsesClient(), # type: ignore[arg-type]
132+
model_is_explicit=False,
133+
)
134+
135+
run_config = RunConfig(model_provider=_SingleModelProvider(model))
136+
agent = Agent(name="prompt-agent", prompt={"id": "pmpt_agent"})
137+
138+
await Runner.run(agent, input="hi", run_config=run_config)
139+
140+
expected_prompt = {"id": "pmpt_agent", "version": None, "variables": None}
141+
assert called_kwargs["prompt"] == expected_prompt
142+
assert called_kwargs["model"] is NOT_GIVEN
143+
assert called_kwargs["tools"] is NOT_GIVEN

tests/test_openai_responses.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ class DummyResponsesClient:
8181
def __init__(self):
8282
self.responses = DummyResponses()
8383

84-
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyResponsesClient()) # type: ignore[arg-type]
84+
model = OpenAIResponsesModel(
85+
model="gpt-4",
86+
openai_client=DummyResponsesClient(), # type: ignore[arg-type]
87+
model_is_explicit=False,
88+
)
8589

8690
await model.get_response(
8791
system_instructions=None,
@@ -96,3 +100,38 @@ def __init__(self):
96100

97101
assert called_kwargs["prompt"] == {"id": "pmpt_123"}
98102
assert called_kwargs["model"] is NOT_GIVEN
103+
104+
105+
@pytest.mark.allow_call_model_methods
106+
@pytest.mark.asyncio
107+
async def test_prompt_id_omits_tools_parameter_when_no_tools_configured():
108+
called_kwargs: dict[str, Any] = {}
109+
110+
class DummyResponses:
111+
async def create(self, **kwargs):
112+
nonlocal called_kwargs
113+
called_kwargs = kwargs
114+
return get_response_obj([])
115+
116+
class DummyResponsesClient:
117+
def __init__(self):
118+
self.responses = DummyResponses()
119+
120+
model = OpenAIResponsesModel(
121+
model="gpt-4",
122+
openai_client=DummyResponsesClient(), # type: ignore[arg-type]
123+
model_is_explicit=False,
124+
)
125+
126+
await model.get_response(
127+
system_instructions=None,
128+
input="hi",
129+
model_settings=ModelSettings(),
130+
tools=[],
131+
output_schema=None,
132+
handoffs=[],
133+
tracing=ModelTracing.DISABLED,
134+
prompt={"id": "pmpt_123"},
135+
)
136+
137+
assert called_kwargs["tools"] is NOT_GIVEN

0 commit comments

Comments (0)