Skip to content

Commit bc5aca3

Browse files
authored
fix: invalid model setting when passing prompt to Agent (#1852)
1 parent c6569cb commit bc5aca3

File tree

4 files changed: +138 −8 lines changed

src/agents/models/openai_provider.py

Lines changed: 8 additions & 4 deletions

@@ -81,13 +81,17 @@ def _get_client(self) -> AsyncOpenAI:
         return self._client
 
     def get_model(self, model_name: str | None) -> Model:
-        if model_name is None:
-            model_name = get_default_model()
+        model_is_explicit = model_name is not None
+        resolved_model_name = model_name if model_name is not None else get_default_model()
 
         client = self._get_client()
 
         return (
-            OpenAIResponsesModel(model=model_name, openai_client=client)
+            OpenAIResponsesModel(
+                model=resolved_model_name,
+                openai_client=client,
+                model_is_explicit=model_is_explicit,
+            )
             if self._use_responses
-            else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
+            else OpenAIChatCompletionsModel(model=resolved_model_name, openai_client=client)
         )

src/agents/models/openai_responses.py

Lines changed: 11 additions & 2 deletions

@@ -67,8 +67,11 @@ def __init__(
         self,
         model: str | ChatModel,
         openai_client: AsyncOpenAI,
+        *,
+        model_is_explicit: bool = True,
     ) -> None:
         self.model = model
+        self._model_is_explicit = model_is_explicit
         self._client = openai_client
 
     def _non_null_or_omit(self, value: Any) -> Any:
@@ -262,6 +265,12 @@ async def _fetch_response(
         converted_tools = Converter.convert_tools(tools, handoffs)
         converted_tools_payload = _to_dump_compatible(converted_tools.tools)
         response_format = Converter.get_response_format(output_schema)
+        should_omit_model = prompt is not None and not self._model_is_explicit
+        model_param: str | ChatModel | Omit = self.model if not should_omit_model else omit
+        should_omit_tools = prompt is not None and len(converted_tools_payload) == 0
+        tools_param: list[ToolParam] | Omit = (
+            converted_tools_payload if not should_omit_tools else omit
+        )
 
         include_set: set[str] = set(converted_tools.includes)
         if model_settings.response_include is not None:
@@ -309,10 +318,10 @@ async def _fetch_response(
             previous_response_id=self._non_null_or_omit(previous_response_id),
             conversation=self._non_null_or_omit(conversation_id),
             instructions=self._non_null_or_omit(system_instructions),
-            model=self.model,
+            model=model_param,
             input=list_input,
             include=include,
-            tools=converted_tools_payload,
+            tools=tools_param,
             prompt=self._non_null_or_omit(prompt),
             temperature=self._non_null_or_omit(model_settings.temperature),
             top_p=self._non_null_or_omit(model_settings.top_p),

tests/test_agent_prompt.py

Lines changed: 47 additions & 2 deletions

@@ -1,8 +1,13 @@
+from __future__ import annotations
+
 import pytest
+from openai import omit
 
-from agents import Agent, Prompt, RunContextWrapper, Runner
+from agents import Agent, Prompt, RunConfig, RunContextWrapper, Runner
+from agents.models.interface import Model, ModelProvider
+from agents.models.openai_responses import OpenAIResponsesModel
 
-from .fake_model import FakeModel
+from .fake_model import FakeModel, get_response_obj
 from .test_responses import get_text_message
 
 
@@ -97,3 +102,43 @@ async def test_prompt_is_passed_to_model():
         "variables": None,
     }
     assert model.last_prompt == expected_prompt
+
+
+class _SingleModelProvider(ModelProvider):
+    def __init__(self, model: Model):
+        self._model = model
+
+    def get_model(self, model_name: str | None) -> Model:
+        return self._model
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_agent_prompt_with_default_model_omits_model_and_tools_parameters():
+    called_kwargs: dict[str, object] = {}
+
+    class DummyResponses:
+        async def create(self, **kwargs):
+            nonlocal called_kwargs
+            called_kwargs = kwargs
+            return get_response_obj([get_text_message("done")])
+
+    class DummyResponsesClient:
+        def __init__(self):
+            self.responses = DummyResponses()
+
+    model = OpenAIResponsesModel(
+        model="gpt-4.1",
+        openai_client=DummyResponsesClient(),  # type: ignore[arg-type]
+        model_is_explicit=False,
+    )
+
+    run_config = RunConfig(model_provider=_SingleModelProvider(model))
+    agent = Agent(name="prompt-agent", prompt={"id": "pmpt_agent"})
+
+    await Runner.run(agent, input="hi", run_config=run_config)
+
+    expected_prompt = {"id": "pmpt_agent", "version": None, "variables": None}
+    assert called_kwargs["prompt"] == expected_prompt
+    assert called_kwargs["model"] is omit
+    assert called_kwargs["tools"] is omit

tests/test_openai_responses.py

Lines changed: 72 additions & 0 deletions

@@ -3,6 +3,7 @@
 from typing import Any
 
 import pytest
+from openai import omit
 from openai.types.responses import ResponseCompletedEvent
 
 from agents import ModelSettings, ModelTracing, __version__
@@ -63,3 +64,74 @@ def __init__(self):
 
     assert "extra_headers" in called_kwargs
     assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_prompt_id_omits_model_parameter():
+    called_kwargs: dict[str, Any] = {}
+
+    class DummyResponses:
+        async def create(self, **kwargs):
+            nonlocal called_kwargs
+            called_kwargs = kwargs
+            return get_response_obj([])
+
+    class DummyResponsesClient:
+        def __init__(self):
+            self.responses = DummyResponses()
+
+    model = OpenAIResponsesModel(
+        model="gpt-4",
+        openai_client=DummyResponsesClient(),  # type: ignore[arg-type]
+        model_is_explicit=False,
+    )
+
+    await model.get_response(
+        system_instructions=None,
+        input="hi",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        prompt={"id": "pmpt_123"},
+    )
+
+    assert called_kwargs["prompt"] == {"id": "pmpt_123"}
+    assert called_kwargs["model"] is omit
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_prompt_id_omits_tools_parameter_when_no_tools_configured():
+    called_kwargs: dict[str, Any] = {}
+
+    class DummyResponses:
+        async def create(self, **kwargs):
+            nonlocal called_kwargs
+            called_kwargs = kwargs
+            return get_response_obj([])
+
+    class DummyResponsesClient:
+        def __init__(self):
+            self.responses = DummyResponses()
+
+    model = OpenAIResponsesModel(
+        model="gpt-4",
+        openai_client=DummyResponsesClient(),  # type: ignore[arg-type]
+        model_is_explicit=False,
+    )
+
+    await model.get_response(
+        system_instructions=None,
+        input="hi",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        prompt={"id": "pmpt_123"},
+    )
+
+    assert called_kwargs["tools"] is omit

0 commit comments

Comments (0)