Commit 9e2b7d7

Merge branch 'main' into mistral-provider

2 parents e3cd1ac + 5579347


66 files changed (+2633 −1100 lines)

.github/workflows/ci.yml

+1-1
@@ -212,7 +212,7 @@ jobs:
         path: index.html

       - run: uv run coverage report --fail-under 95
-      - run: uv run diff-cover coverage.xml --fail-under 95
+      - run: uv run diff-cover coverage.xml --fail-under 100

   # https://github.com/marketplace/actions/alls-green#why used for branch protection checks
   check:

Makefile

+5-5
@@ -49,11 +49,11 @@ test: ## Run tests and collect coverage data

 .PHONY: test-all-python
 test-all-python: ## Run tests on Python 3.9 to 3.13
-	UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 --all-extras coverage run -p -m pytest
-	UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 --all-extras coverage run -p -m pytest
-	UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 --all-extras coverage run -p -m pytest
-	UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 --all-extras coverage run -p -m pytest
-	UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 --all-extras coverage run -p -m pytest
+	UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 --all-extras --all-packages coverage run -p -m pytest
+	UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 --all-extras --all-packages coverage run -p -m pytest
+	UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 --all-extras --all-packages coverage run -p -m pytest
+	UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 --all-extras --all-packages coverage run -p -m pytest
+	UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 --all-extras --all-packages coverage run -p -m pytest
 	@uv run coverage combine
 	@uv run coverage report

docs/api/providers.md

+2
@@ -11,3 +11,5 @@
 ::: pydantic_ai.providers.bedrock

 ::: pydantic_ai.providers.groq
+
+::: pydantic_ai.providers.azure

docs/api/pydantic_graph/nodes.md

+1
@@ -3,6 +3,7 @@
 ::: pydantic_graph.nodes
     options:
       members:
+        - StateT
         - GraphRunContext
         - BaseNode
         - End

docs/api/pydantic_graph/persistence.md

+7

@@ -0,0 +1,7 @@
+# `pydantic_graph.persistence`
+
+::: pydantic_graph.persistence
+
+::: pydantic_graph.persistence.in_mem
+
+::: pydantic_graph.persistence.file

docs/api/pydantic_graph/state.md

-3
This file was deleted.

docs/graph.md

+196-373
Large diffs are not rendered by default.

docs/logfire.md

+1-1
@@ -151,7 +151,7 @@ agent = Agent('openai:gpt-4o', instrument=instrumentation_settings)
 Agent.instrument_all(instrumentation_settings)
 ```

-For now, this won't look as good in the Logfire UI, but we're working on it. **Once the UI supports it, `event_mode='logs'` will become the default.**
+For now, this won't look as good in the Logfire UI, but we're working on it.

 If you have very long conversations, the `events` span attribute may be truncated. Using `event_mode='logs'` will help avoid this issue.
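
Aside: a minimal sketch of opting into `event_mode='logs'`, based on the `InstrumentationSettings` usage visible in this file's hunk context (the import path is an assumption, not part of this diff):

```python
# Minimal sketch, assuming InstrumentationSettings accepts event_mode
# as described in the prose above.
from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

# 'logs' emits one OTel event per message instead of a single span
# attribute, which avoids truncation on very long conversations
instrumentation_settings = InstrumentationSettings(event_mode='logs')

agent = Agent('openai:gpt-4o', instrument=instrumentation_settings)
# or instrument every agent at once:
Agent.instrument_all(instrumentation_settings)
```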

docs/models.md

+22
@@ -828,6 +828,28 @@ Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details
 1. The name of the model running on the remote server
 2. The url of the remote server

+### Azure AI Foundry
+
+If you want to use [Azure AI Foundry](https://ai.azure.com/) as your provider, you can do so by using the
+[`AzureProvider`][pydantic_ai.providers.azure.AzureProvider] class.
+
+```python {title="azure_provider_example.py"}
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.azure import AzureProvider
+
+model = OpenAIModel(
+    'gpt-4o',
+    provider=AzureProvider(
+        azure_endpoint='your-azure-endpoint',
+        api_version='your-api-version',
+        api_key='your-api-key',
+    ),
+)
+agent = Agent(model)
+...
+```
+
 ### OpenRouter

 To use [OpenRouter](https://openrouter.ai), first create an API key at [openrouter.ai/keys](https://openrouter.ai/keys).

examples/pydantic_ai_examples/question_graph.py

+37-56
@@ -9,11 +9,16 @@

 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Annotated

 import logfire
-from devtools import debug
-from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, HistoryStep
+from pydantic import BaseModel
+from pydantic_graph import (
+    BaseNode,
+    End,
+    Graph,
+    GraphRunContext,
+)
+from pydantic_graph.persistence.file import FileStatePersistence

 from pydantic_ai import Agent
 from pydantic_ai.format_as_xml import format_as_xml
@@ -41,22 +46,23 @@ async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
         )
         ctx.state.ask_agent_messages += result.all_messages()
         ctx.state.question = result.data
-        return Answer()
+        return Answer(result.data)


 @dataclass
 class Answer(BaseNode[QuestionState]):
-    answer: str | None = None
+    question: str

     async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
-        assert self.answer is not None
-        return Evaluate(self.answer)
+        answer = input(f'{self.question}: ')
+        return Evaluate(answer)


-@dataclass
-class EvaluationResult:
+class EvaluationResult(BaseModel, use_attribute_docstrings=True):
     correct: bool
+    """Whether the answer is correct."""
     comment: str
+    """Comment on the answer, reprimand the user if the answer is wrong."""


 evaluate_agent = Agent(
@@ -67,101 +73,76 @@ class EvaluationResult:


 @dataclass
-class Evaluate(BaseNode[QuestionState]):
+class Evaluate(BaseNode[QuestionState, None, str]):
     answer: str

     async def run(
         self,
         ctx: GraphRunContext[QuestionState],
-    ) -> Congratulate | Reprimand:
+    ) -> End[str] | Reprimand:
         assert ctx.state.question is not None
         result = await evaluate_agent.run(
             format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
             message_history=ctx.state.evaluate_agent_messages,
         )
         ctx.state.evaluate_agent_messages += result.all_messages()
         if result.data.correct:
-            return Congratulate(result.data.comment)
+            return End(result.data.comment)
         else:
             return Reprimand(result.data.comment)


-@dataclass
-class Congratulate(BaseNode[QuestionState, None, None]):
-    comment: str
-
-    async def run(
-        self, ctx: GraphRunContext[QuestionState]
-    ) -> Annotated[End, Edge(label='success')]:
-        print(f'Correct answer! {self.comment}')
-        return End(None)
-
-
 @dataclass
 class Reprimand(BaseNode[QuestionState]):
     comment: str

     async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
         print(f'Comment: {self.comment}')
-        # > Comment: Vichy is no longer the capital of France.
         ctx.state.question = None
         return Ask()


 question_graph = Graph(
-    nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand), state_type=QuestionState
+    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
 )


 async def run_as_continuous():
     state = QuestionState()
     node = Ask()
-    history: list[HistoryStep[QuestionState, None]] = []
-    with logfire.span('run questions graph'):
-        while True:
-            node = await question_graph.next(node, history, state=state)
-            if isinstance(node, End):
-                debug([e.data_snapshot() for e in history])
-                break
-            elif isinstance(node, Answer):
-                assert state.question
-                node.answer = input(f'{state.question} ')
-            # otherwise just continue
+    end = await question_graph.run(node, state=state)
+    print('END:', end.output)


 async def run_as_cli(answer: str | None):
-    history_file = Path('question_graph_history.json')
-    history = (
-        question_graph.load_history(history_file.read_bytes())
-        if history_file.exists()
-        else []
-    )
-
-    if history:
-        last = history[-1]
-        assert last.kind == 'node', 'expected last step to be a node'
-        state = last.state
-        assert answer is not None, 'answer is required to continue from history'
-        node = Answer(answer)
+    persistence = FileStatePersistence(Path('question_graph.json'))
+    persistence.set_graph_types(question_graph)
+
+    if snapshot := await persistence.load_next():
+        state = snapshot.state
+        assert answer is not None, (
+            'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli <answer>"'
+        )
+        node = Evaluate(answer)
     else:
         state = QuestionState()
         node = Ask()
-    debug(state, node)
+    # debug(state, node)

-    with logfire.span('run questions graph'):
+    async with question_graph.iter(node, state=state, persistence=persistence) as run:
         while True:
-            node = await question_graph.next(node, history, state=state)
+            node = await run.next()
             if isinstance(node, End):
-                debug([e.data_snapshot() for e in history])
+                print('END:', node.data)
+                history = await persistence.load_all()
+                print('history:', '\n'.join(str(e.node) for e in history), sep='\n')
                 print('Finished!')
                 break
             elif isinstance(node, Answer):
-                print(state.question)
+                print(node.question)
                 break
             # otherwise just continue

-    history_file.write_bytes(question_graph.dump_history(history, indent=2))
-

 if __name__ == '__main__':
     import asyncio
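
Aside: a hedged sketch of reading back the run history that `FileStatePersistence` records for this example; it reuses only the `load_all` and `set_graph_types` calls shown in the diff above, and the module path is assumed from this repo's layout:

```python
# Hedged sketch: print the nodes recorded by FileStatePersistence,
# reusing the graph and file name from the example above.
import asyncio
from pathlib import Path

from pydantic_graph.persistence.file import FileStatePersistence

from pydantic_ai_examples.question_graph import question_graph


async def show_history() -> None:
    persistence = FileStatePersistence(Path('question_graph.json'))
    persistence.set_graph_types(question_graph)  # bind snapshot types to this graph
    for snapshot in await persistence.load_all():
        print(snapshot.node)


asyncio.run(show_history())
```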

examples/pyproject.toml

+2-2
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "pydantic-ai-examples"
-version = "0.0.39"
+version = "0.0.40"
 description = "Examples of how to use PydanticAI and what it can do."
 authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
 license = "MIT"
@@ -32,7 +32,7 @@ classifiers = [
 ]
 requires-python = ">=3.9"
 dependencies = [
-    "pydantic-ai-slim[openai,vertexai,groq,anthropic]==0.0.39",
+    "pydantic-ai-slim[openai,vertexai,groq,anthropic]==0.0.40",
     "asyncpg>=0.30.0",
     "fastapi>=0.115.4",
     "logfire[asyncpg,fastapi,sqlite3]>=2.6",

mkdocs.yml

+1-1
@@ -66,7 +66,7 @@ nav:
     - api/providers.md
     - api/pydantic_graph/graph.md
     - api/pydantic_graph/nodes.md
-    - api/pydantic_graph/state.md
+    - api/pydantic_graph/persistence.md
     - api/pydantic_graph/mermaid.md
     - api/pydantic_graph/exceptions.md


pydantic_ai_slim/pydantic_ai/agent.py

+1-1
@@ -475,8 +475,8 @@ async def main():
             start_node,
             state=state,
             deps=graph_deps,
-            infer_name=False,
             span=use_span(run_span, end_on_exit=True),
+            infer_name=False,
         ) as graph_run:
             yield AgentRun(graph_run)

pydantic_ai_slim/pydantic_ai/models/__init__.py

+8-2
@@ -262,8 +262,14 @@ def model_name(self) -> str:

     @property
     @abstractmethod
-    def system(self) -> str | None:
-        """The system / model provider, ex: openai."""
+    def system(self) -> str:
+        """The system / model provider, ex: openai.
+
+        Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute,
+        so should use well-known values listed in
+        https://opentelemetry.io/docs/specs/semconv/attributes-registry/gen-ai/#gen-ai-system
+        when applicable.
+        """
         raise NotImplementedError()

     @property
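
Aside: a hypothetical sketch of what the tightened contract asks of a custom model; the class name and return values are illustrative only, and the other abstract members of `Model` are elided:

```python
# Hypothetical sketch: `system` now returns `str` (not `str | None`) and
# should use a well-known gen_ai.system value where one applies.
from pydantic_ai.models import Model


class MyCustomModel(Model):  # other abstract members elided for brevity
    @property
    def model_name(self) -> str:
        return 'my-model'

    @property
    def system(self) -> str:
        return 'openai'  # a well-known value from the OTel gen-ai registry
```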

pydantic_ai_slim/pydantic_ai/models/anthropic.py

+13-10
@@ -33,13 +33,7 @@
 )
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    cached_async_http_client,
-    check_allow_model_requests,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests

 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
@@ -115,7 +109,7 @@ class AnthropicModel(Model):
     client: AsyncAnthropic = field(repr=False)

     _model_name: AnthropicModelName = field(repr=False)
-    _system: str | None = field(default='anthropic', repr=False)
+    _system: str = field(default='anthropic', repr=False)

     def __init__(
         self,
@@ -183,7 +177,7 @@ def model_name(self) -> AnthropicModelName:
         return self._model_name

     @property
-    def system(self) -> str | None:
+    def system(self) -> str:
         """The system / model provider."""
         return self._system

@@ -355,8 +349,17 @@ async def _map_user_prompt(
                     source={'data': io.BytesIO(item.data), 'media_type': item.media_type, 'type': 'base64'},  # type: ignore
                     type='image',
                 )
+            elif item.media_type == 'application/pdf':
+                yield DocumentBlockParam(
+                    source=Base64PDFSourceParam(
+                        data=io.BytesIO(item.data),
+                        media_type='application/pdf',
+                        type='base64',
+                    ),
+                    type='document',
+                )
             else:
-                raise RuntimeError('Only images are supported for binary content')
+                raise RuntimeError('Only images and PDFs are supported for binary content')
         elif isinstance(item, ImageUrl):
             try:
                 response = await cached_async_http_client().get(item.url)
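
Aside: a hedged sketch of exercising the new PDF branch from user code; `BinaryContent`, the file name, and the model name are assumptions, not part of this diff:

```python
# Hedged sketch: send a PDF to an Anthropic model as binary content,
# which routes through the new 'application/pdf' branch above.
from pathlib import Path

from pydantic_ai import Agent
from pydantic_ai.messages import BinaryContent

agent = Agent('anthropic:claude-3-5-sonnet-latest')
pdf = BinaryContent(data=Path('report.pdf').read_bytes(), media_type='application/pdf')
result = agent.run_sync(['Summarize this document.', pdf])
print(result.data)
```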

pydantic_ai_slim/pydantic_ai/models/bedrock.py

+2-2
@@ -119,15 +119,15 @@ class BedrockConverseModel(Model):
     client: BedrockRuntimeClient

     _model_name: BedrockModelName = field(repr=False)
-    _system: str | None = field(default='bedrock', repr=False)
+    _system: str = field(default='bedrock', repr=False)

     @property
     def model_name(self) -> str:
         """The model name."""
         return self._model_name

     @property
-    def system(self) -> str | None:
+    def system(self) -> str:
         """The system / model provider, ex: openai."""
         return self._system

