InstrumentedModel and FallbackModel fixes #1121

Merged: 6 commits, Mar 14, 2025
Changes from all commits
10 changes: 8 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -262,8 +262,14 @@ def model_name(self) -> str:

@property
@abstractmethod
def system(self) -> str | None:
"""The system / model provider, ex: openai."""
def system(self) -> str:
"""The system / model provider, ex: openai.

Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute,
so should use well-known values listed in
https://opentelemetry.io/docs/specs/semconv/attributes-registry/gen-ai/#gen-ai-system
when applicable.
"""
raise NotImplementedError()

@property
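
Since `system` is now required to return a concrete string, callers can feed it straight into the `gen_ai.system` attribute without a `None` check. A minimal sketch of the new contract using `TestModel`, whose default was changed from `None` to `'test'` in this PR, so no provider credentials are needed:

```python
from pydantic_ai.models.test import TestModel

# `Model.system` is now typed as `str` rather than `str | None`, so it can be
# written directly into the `gen_ai.system` span attribute.
model = TestModel()
print(model.model_name)  # 'test'
print(model.system)      # 'test'
```
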
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -115,7 +115,7 @@ class AnthropicModel(Model):
client: AsyncAnthropic = field(repr=False)

_model_name: AnthropicModelName = field(repr=False)
_system: str | None = field(default='anthropic', repr=False)
_system: str = field(default='anthropic', repr=False)

def __init__(
self,
@@ -183,7 +183,7 @@ def model_name(self) -> AnthropicModelName:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -119,15 +119,15 @@ class BedrockConverseModel(Model):
client: BedrockRuntimeClient

_model_name: BedrockModelName = field(repr=False)
_system: str | None = field(default='bedrock', repr=False)
_system: str = field(default='bedrock', repr=False)

@property
def model_name(self) -> str:
"""The model name."""
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider, ex: openai."""
return self._system

4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -98,7 +98,7 @@ class CohereModel(Model):
client: AsyncClientV2 = field(repr=False)

_model_name: CohereModelName = field(repr=False)
_system: str | None = field(default='cohere', repr=False)
_system: str = field(default='cohere', repr=False)

def __init__(
self,
@@ -148,7 +148,7 @@ def model_name(self) -> CohereModelName:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

24 changes: 16 additions & 8 deletions pydantic_ai_slim/pydantic_ai/models/fallback.py
@@ -1,10 +1,14 @@
from __future__ import annotations as _annotations

from collections.abc import AsyncIterator
from contextlib import AsyncExitStack, asynccontextmanager
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable

from opentelemetry.trace import get_current_span

from pydantic_ai.models.instrumented import InstrumentedModel

from ..exceptions import FallbackExceptionGroup, ModelHTTPError
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

@@ -40,7 +44,6 @@ def __init__(
fallback_on: A callable or tuple of exceptions that should trigger a fallback.
"""
self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]
self._model_name = f'FallBackModel[{", ".join(model.model_name for model in self.models)}]'

if isinstance(fallback_on, tuple):
self._fallback_on = _default_fallback_condition_factory(fallback_on)
@@ -62,13 +65,19 @@ async def request(
for model in self.models:
try:
response, usage = await model.request(messages, model_settings, model_request_parameters)
response.model_used = model # type: ignore
return response, usage
except Exception as exc:
if self._fallback_on(exc):
exceptions.append(exc)
continue
raise exc
else:
with suppress(Exception):
span = get_current_span()
if span.is_recording():
attributes = getattr(span, 'attributes', {})
if attributes.get('gen_ai.request.model') == self.model_name:
span.set_attributes(InstrumentedModel.model_attributes(model))
return response, usage

raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

@@ -101,12 +110,11 @@ async def request_stream(
@property
def model_name(self) -> str:
"""The model name."""
return self._model_name
return f'fallback:{",".join(model.model_name for model in self.models)}'

@property
def system(self) -> str | None:
"""The system / model provider, n/a for fallback models."""
return None
def system(self) -> str:
return f'fallback:{",".join(model.system for model in self.models)}'

@property
def base_url(self) -> str | None:
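
For illustration, a small sketch of the new lazily derived names as seen from the caller's side; the `reply` and `backup_reply` functions below are placeholders invented for this example, not code from the PR:

```python
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.function import AgentInfo, FunctionModel


def reply(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
    return ModelResponse(parts=[TextPart('primary ok')])


def backup_reply(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
    return ModelResponse(parts=[TextPart('backup ok')])


fallback = FallbackModel(FunctionModel(reply), FunctionModel(backup_reply))

# model_name is now computed on access with a `fallback:` prefix instead of the
# eagerly built `FallBackModel[...]` string, and `system` mirrors that shape.
print(fallback.model_name)  # e.g. 'fallback:function:reply:,function:backup_reply:'
print(fallback.system)      # 'fallback:function,function'
```

When the agent runs under `InstrumentedModel`, the `request` path above also overwrites the current span's `gen_ai.*` attributes with those of the model that actually answered, replacing the old `model_used` back-channel on the response.
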
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/function.py
@@ -45,7 +45,7 @@ class FunctionModel(Model):
stream_function: StreamFunctionDef | None = None

_model_name: str = field(repr=False)
_system: str | None = field(default=None, repr=False)
_system: str = field(default='function', repr=False)

@overload
def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
@@ -140,7 +140,7 @@ def model_name(self) -> str:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -91,7 +91,7 @@ class GeminiModel(Model):
_provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
_auth: AuthProtocol | None = field(repr=False)
_url: str | None = field(repr=False)
_system: str | None = field(default='google-gla', repr=False)
_system: str = field(default='gemini', repr=False)

@overload
def __init__(
@@ -197,7 +197,7 @@ def model_name(self) -> GeminiModelName:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -88,7 +88,7 @@ class GroqModel(Model):
client: AsyncGroq = field(repr=False)

_model_name: GroqModelName = field(repr=False)
_system: str | None = field(default='groq', repr=False)
_system: str = field(default='groq', repr=False)

@overload
def __init__(
@@ -186,7 +186,7 @@ def model_name(self) -> GroqModelName:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

10 changes: 2 additions & 8 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
@@ -175,11 +175,7 @@ def finish(response: ModelResponse, usage: Usage):
)
)
new_attributes: dict[str, AttributeValue] = usage.opentelemetry_attributes() # type: ignore
if model_used := getattr(response, 'model_used', None):
# FallbackModel sets model_used on the response so that we can report the attributes
# of the model that was actually used.
new_attributes.update(self.model_attributes(model_used))
attributes.update(new_attributes)
attributes.update(getattr(span, 'attributes', {}))
request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
new_attributes['gen_ai.response.model'] = response.model_name or request_model
span.set_attributes(new_attributes)
@@ -213,10 +209,8 @@ def _emit_events(self, span: Span, events: list[Event]) -> None:

@staticmethod
def model_attributes(model: Model):
system = getattr(model, 'system', '') or model.__class__.__name__.removesuffix('Model').lower()
system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
attributes: dict[str, AttributeValue] = {
GEN_AI_SYSTEM_ATTRIBUTE: system,
GEN_AI_SYSTEM_ATTRIBUTE: model.system,
GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
}
if base_url := model.base_url:
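
A rough sketch of what the simplified `model_attributes` produces, again using `TestModel` so nothing external is required; the exact output shown is an assumption based on the defaults introduced in this PR:

```python
from pydantic_ai.models.instrumented import InstrumentedModel
from pydantic_ai.models.test import TestModel

# With `system` guaranteed on every model, the class-name fallback and the
# google-gla/google-vertex/mistral remapping table are no longer needed here.
attrs = InstrumentedModel.model_attributes(TestModel())
print(attrs)  # expected: {'gen_ai.system': 'test', 'gen_ai.request.model': 'test'}
```
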
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -110,7 +110,7 @@ class MistralModel(Model):
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

_model_name: MistralModelName = field(repr=False)
_system: str | None = field(default='mistral', repr=False)
_system: str = field(default='mistral_ai', repr=False)

def __init__(
self,
@@ -179,7 +179,7 @@ def model_name(self) -> MistralModelName:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

10 changes: 5 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -99,7 +99,7 @@ class OpenAIModel(Model):
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

_model_name: OpenAIModelName = field(repr=False)
_system: str | None = field(repr=False)
_system: str = field(repr=False)

@overload
def __init__(
@@ -108,7 +108,7 @@ def __init__(
*,
provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
system_prompt_role: OpenAISystemPromptRole | None = None,
system: str | None = 'openai',
system: str = 'openai',
) -> None: ...

@deprecated('Use the `provider` parameter instead of `base_url`, `api_key`, `openai_client` and `http_client`.')
@@ -123,7 +123,7 @@ def __init__(
openai_client: AsyncOpenAI | None = None,
http_client: AsyncHTTPClient | None = None,
system_prompt_role: OpenAISystemPromptRole | None = None,
system: str | None = 'openai',
system: str = 'openai',
) -> None: ...

def __init__(
@@ -136,7 +136,7 @@ def __init__(
openai_client: AsyncOpenAI | None = None,
http_client: AsyncHTTPClient | None = None,
system_prompt_role: OpenAISystemPromptRole | None = None,
system: str | None = 'openai',
system: str = 'openai',
):
"""Initialize an OpenAI model.

@@ -224,7 +224,7 @@ def model_name(self) -> OpenAIModelName:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/test.py
@@ -79,7 +79,7 @@ class TestModel(Model):
This is set when a request is made, so will reflect the function tools from the last step of the last run.
"""
_model_name: str = field(default='test', repr=False)
_system: str | None = field(default=None, repr=False)
_system: str = field(default='test', repr=False)

async def request(
self,
@@ -113,7 +113,7 @@ def model_name(self) -> str:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/vertexai.py
@@ -69,7 +69,7 @@ class VertexAIModel(GeminiModel):
url_template: str

_model_name: GeminiModelName = field(repr=False)
_system: str | None = field(default='google-vertex', repr=False)
_system: str = field(default='vertex_ai', repr=False)

# TODO __init__ can be removed once we drop 3.9 and we can set kw_only correctly on the dataclass
def __init__(
@@ -175,7 +175,7 @@ def model_name(self) -> GeminiModelName:
return self._model_name

@property
def system(self) -> str | None:
def system(self) -> str:
"""The system / model provider."""
return self._system

2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/wrapper.py
@@ -38,7 +38,7 @@ def model_name(self) -> str:
return self.wrapped.model_name

@property
def system(self) -> str | None:
def system(self) -> str:
return self.wrapped.system

def __getattr__(self, item: str):
11 changes: 0 additions & 11 deletions tests/conftest.py
@@ -45,17 +45,6 @@ def IsNow(*args: Any, **kwargs: Any):
return _IsNow(*args, **kwargs)


try:
from logfire.testing import CaptureLogfire
except ImportError:
pass
else:

@pytest.fixture(autouse=True)
def logfire_disable(capfire: CaptureLogfire):
pass


class TestEnv:
__test__ = False

10 changes: 4 additions & 6 deletions tests/models/test_fallback.py
@@ -41,10 +41,8 @@ def failure_response(_model_messages: list[ModelMessage], _agent_info: AgentInfo

def test_init() -> None:
fallback_model = FallbackModel(failure_model, success_model)
assert fallback_model.model_name == snapshot(
'FallBackModel[function:failure_response:, function:success_response:]'
)
assert fallback_model.system is None
assert fallback_model.model_name == snapshot('fallback:function:failure_response:,function:success_response:')
assert fallback_model.system == 'fallback:function,function'
assert fallback_model.base_url is None


@@ -139,7 +137,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None:
'attributes': {
'gen_ai.operation.name': 'chat',
'logfire.span_type': 'span',
'logfire.msg': 'chat FallBackModel[function:failure_response:, function:success_response:]',
'logfire.msg': 'chat fallback:function:failure_response:,function:success_response:',
'gen_ai.usage.input_tokens': 51,
'gen_ai.usage.output_tokens': 1,
'gen_ai.system': 'function',
@@ -172,7 +170,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None:
'start_time': 1000000000,
'end_time': 6000000000,
'attributes': {
'model_name': 'FallBackModel[function:failure_response:, function:success_response:]',
'model_name': 'fallback:function:failure_response:,function:success_response:',
'agent_name': 'agent',
'logfire.msg': 'agent run',
'logfire.span_type': 'span',
2 changes: 1 addition & 1 deletion tests/models/test_model.py
@@ -57,7 +57,7 @@
'MISTRAL_API_KEY',
'mistral:mistral-small-latest',
'mistral-small-latest',
'mistral',
'mistral_ai',
'mistral',
'MistralModel',
),
4 changes: 2 additions & 2 deletions tests/models/test_vertexai.py
@@ -44,7 +44,7 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
)
assert model.auth is not None
assert model.model_name == snapshot('gemini-1.5-flash')
assert model.system == snapshot('google-vertex')
assert model.system == snapshot('vertex_ai')


class NoOpCredentials:
@@ -72,7 +72,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
)
assert model.auth is not None
assert model.model_name == snapshot('gemini-1.5-flash')
assert model.system == snapshot('google-vertex')
assert model.system == snapshot('vertex_ai')

await model.ainit()
assert model.base_url is not None