Make user_prompt optional #1406

Merged · 2 commits · Apr 9, 2025
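
With this change, `user_prompt` can be omitted everywhere it is accepted, so an agent can be run against its system prompt alone. A minimal sketch of the usage this enables, mirroring the new test added in this PR (the model string and the `asyncio.run` wiring are illustrative assumptions, not part of the diff):

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:o3-mini', system_prompt='You are a potato.')


async def main():
    # No user_prompt argument: the request is built from the system prompt only.
    result = await agent.run()
    print(result.data)


asyncio.run(main())
```
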
14 changes: 8 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -79,7 +79,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

     user_deps: DepsT

-    prompt: str | Sequence[_messages.UserContent]
+    prompt: str | Sequence[_messages.UserContent] | None
     new_message_index: int

     model: models.Model
@@ -124,7 +124,7 @@ def is_agent_node(

 @dataclasses.dataclass
 class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
-    user_prompt: str | Sequence[_messages.UserContent]
+    user_prompt: str | Sequence[_messages.UserContent] | None

     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
@@ -151,7 +151,7 @@ async def _get_first_message(

     async def _prepare_messages(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None,
         message_history: list[_messages.ModelMessage] | None,
         run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
@@ -166,16 +166,18 @@ async def _prepare_messages(
                 messages = ctx_messages.messages
                 ctx_messages.used = True

+        parts: list[_messages.ModelRequestPart] = []
         if message_history:
             # Shallow copy messages
             messages.extend(message_history)
             # Reevaluate any dynamic system prompt parts
             await self._reevaluate_dynamic_prompts(messages, run_context)
-            return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
         else:
-            parts = await self._sys_parts(run_context)
+            parts.extend(await self._sys_parts(run_context))
+
+        if user_prompt is not None:
             parts.append(_messages.UserPromptPart(user_prompt))
-            return messages, _messages.ModelRequest(parts)
+        return messages, _messages.ModelRequest(parts)

     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
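
At the graph level this means a run no longer has to contribute a new `UserPromptPart`, so it should also become possible to ask the model to continue from existing history alone. A hedged sketch of that use (how useful an empty follow-up request is depends on the model; `all_messages()` is the usual way to carry history forward):

```python
from pydantic_ai import Agent

agent = Agent('openai:o3-mini', system_prompt='You are a potato.')


async def main():
    first = await agent.run('Introduce yourself in one sentence.')
    # Continue the conversation without adding a new user message: with this PR,
    # user_prompt may be omitted when message_history is supplied.
    followup = await agent.run(message_history=first.all_messages())
    print(followup.data)
```
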
14 changes: 7 additions & 7 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -242,7 +242,7 @@ def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
     @overload
     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -257,7 +257,7 @@ async def run(
     @overload
     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -271,7 +271,7 @@ async def run(

     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -335,7 +335,7 @@ async def main():
     @asynccontextmanager
     async def iter(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -503,7 +503,7 @@ async def main():
     @overload
     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
@@ -517,7 +517,7 @@ def run_sync(
     @overload
     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -531,7 +531,7 @@ def run_sync(

     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
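
The synchronous and iterator entry points pick up the same default, so both should be callable without a prompt as well. A short sketch, reusing the `agent` defined in the first example above:

```python
# Synchronous variant: runs on the system prompt alone.
result = agent.run_sync()
print(result.data)


# Graph-iteration variant: agent.iter() accepts the same optional prompt.
async def walk_graph():
    async with agent.iter() as agent_run:
        async for node in agent_run:
            print(type(node).__name__)
```
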
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/tools.py
@@ -48,7 +48,7 @@ class RunContext(Generic[AgentDepsT]):
     """The model used in this run."""
     usage: Usage
     """LLM usage associated with the run."""
-    prompt: str | Sequence[_messages.UserContent]
+    prompt: str | Sequence[_messages.UserContent] | None
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
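
Because `RunContext.prompt` is now optional, tool and system-prompt functions that read it should tolerate `None`. A minimal, hypothetical sketch (the tool itself is illustrative and not part of this PR):

```python
from pydantic_ai import Agent, RunContext

agent = Agent('openai:o3-mini', system_prompt='You are a potato.')


@agent.tool
async def original_request(ctx: RunContext[None]) -> str:
    """Report what the user originally asked for, if anything."""
    if ctx.prompt is None:
        # The run may have been started without a user prompt.
        return 'This run was started without a user prompt.'
    return f'The original user prompt was: {ctx.prompt!r}'
```
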
@@ -0,0 +1,78 @@
interactions:
- request:
    headers:
      accept:
      - application/json
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      content-length:
      - '101'
      content-type:
      - application/json
      host:
      - api.openai.com
    method: POST
    parsed_body:
      messages:
      - content: You are a potato.
        role: system
      model: o3-mini
      n: 1
      stream: false
    uri: https://api.openai.com/v1/chat/completions
  response:
    headers:
      access-control-expose-headers:
      - X-Request-ID
      alt-svc:
      - h3=":443"; ma=86400
      connection:
      - keep-alive
      content-length:
      - '906'
      content-type:
      - application/json
      openai-organization:
      - pydantic-28gund
      openai-processing-ms:
      - '8045'
      openai-version:
      - '2020-10-01'
      strict-transport-security:
      - max-age=31536000; includeSubDomains; preload
      transfer-encoding:
      - chunked
    parsed_body:
      choices:
      - finish_reason: stop
        index: 0
        message:
          annotations: []
          content: That's right—I am a potato! A spud of many talents, here to help you out. How can this humble potato be
            of service today?
          refusal: null
          role: assistant
      created: 1744099208
      id: chatcmpl-BJyAKqCjJI3mIdQmTSW6UlG6NKpjm
      model: o3-mini-2025-01-31
      object: chat.completion
      service_tier: default
      system_fingerprint: fp_617f206dd9
      usage:
        completion_tokens: 809
        completion_tokens_details:
          accepted_prediction_tokens: 0
          audio_tokens: 0
          reasoning_tokens: 768
          rejected_prediction_tokens: 0
        prompt_tokens: 11
        prompt_tokens_details:
          audio_tokens: 0
          cached_tokens: 0
        total_tokens: 820
    status:
      code: 200
      message: OK
version: 1
10 changes: 10 additions & 0 deletions tests/models/test_openai.py
@@ -1204,3 +1204,13 @@ class MyModel(BaseModel):
             '$ref': '#/$defs/MyModel',
         }
     )
+
+
+@pytest.mark.vcr
+async def test_openai_model_without_system_prompt(allow_model_requests: None, openai_api_key: str):
+    m = OpenAIModel('o3-mini', provider=OpenAIProvider(api_key=openai_api_key))
+    agent = Agent(m, system_prompt='You are a potato.')
+    result = await agent.run()
+    assert result.data == snapshot(
+        "That's right—I am a potato! A spud of many talents, here to help you out. How can this humble potato be of service today?"
+    )
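
Note: the `pytest.mark.vcr` marker replays a recorded HTTP exchange, and the YAML block earlier in this diff appears to be that recording, so the snapshot assertion above should be deterministic; a live `openai_api_key` should only be needed when re-recording the cassette.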