Skip to content

Add base_url to models, populate server.address and server.port in spans #1074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ def system(self) -> str | None:
"""The system / model provider, ex: openai."""
raise NotImplementedError()

@property
def base_url(self) -> str | None:
"""The base URL for the provider API, if available."""
return None


@dataclass
class StreamedResponse(ABC):
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def __init__(
else:
self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())

@property
def base_url(self) -> str:
    """The endpoint URL configured on the wrapped Anthropic client."""
    configured = self.client.base_url
    return str(configured)

async def request(
self,
messages: list[ModelMessage],
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
}
}

@property
def base_url(self) -> str:
    """The endpoint URL the underlying boto3 client was built with."""
    endpoint = self.client.meta.endpoint_url
    return str(endpoint)

async def request(
self,
messages: list[ModelMessage],
Expand Down
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ def __init__(
else:
self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)

@property
def base_url(self) -> str:
    """The base URL used by the Cohere SDK client."""
    # NOTE(review): reaches into a private attribute; the SDK appears to
    # expose no public accessor for the configured URL — confirm upstream.
    wrapper = self.client._client_wrapper  # type: ignore
    return str(wrapper.get_base_url())

async def request(
self,
messages: list[ModelMessage],
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def system(self) -> str | None:
"""The system / model provider, n/a for fallback models."""
return None

@property
def base_url(self) -> str | None:
return self.models[0].base_url


def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
"""Create a default fallback condition for the given exceptions."""
Expand Down
5 changes: 3 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(
else:
self._system = provider.name
self.client = provider.client
self._url = str(self.client.base_url)
else:
if api_key is None:
if env_api_key := os.getenv('GEMINI_API_KEY'):
Expand All @@ -159,7 +160,7 @@ def auth(self) -> AuthProtocol:
return self._auth

@property
def base_url(self) -> str:
    """The resolved request URL prefix for the Gemini API.

    Raises:
        AssertionError: if the URL has not been initialized yet
            (it is set during model construction / `ainit`).
    """
    url = self._url
    assert url is not None, 'URL not initialized'
    return url

Expand Down Expand Up @@ -257,7 +258,7 @@ async def _make_request(
'User-Agent': get_user_agent(),
}
if self._provider is None: # pragma: no cover
url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
url = self.base_url + ('streamGenerateContent' if streamed else 'generateContent')
headers.update(await self.auth.headers())
else:
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def __init__(
else:
self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())

@property
def base_url(self) -> str:
    """The endpoint URL configured on the wrapped Groq client."""
    configured = self.client.base_url
    return str(configured)

async def request(
self,
messages: list[ModelMessage],
Expand Down
12 changes: 10 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Literal
from urllib.parse import urlparse

from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
Expand Down Expand Up @@ -142,15 +143,22 @@ def _instrument(
system = getattr(self.wrapped, 'system', '') or self.wrapped.__class__.__name__.removesuffix('Model').lower()
system = {'google-gla': 'gemini', 'google-vertex': 'vertex_ai', 'mistral': 'mistral_ai'}.get(system, system)
# TODO Missing attributes:
# - server.address: requires a Model.base_url abstract method or similar
# - server.port: to parse from the base_url
# - error.type: unclear if we should do something here or just always rely on span exceptions
# - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
attributes: dict[str, AttributeValue] = {
'gen_ai.operation.name': operation,
'gen_ai.system': system,
'gen_ai.request.model': model_name,
}
if base_url := self.wrapped.base_url:
try:
parsed = urlparse(base_url)
if parsed.hostname:
attributes['server.address'] = parsed.hostname
if parsed.port:
attributes['server.port'] = parsed.port
except Exception: # pragma: no cover
pass

if model_settings:
for key in MODEL_SETTING_ATTRIBUTES:
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def __init__(
api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())

@property
def base_url(self) -> str:
    """The server URL reported by the Mistral SDK configuration."""
    # The first element returned by get_server_details() is the server URL.
    server_url = self.client.sdk_configuration.get_server_details()[0]
    return str(server_url)

async def request(
self,
messages: list[ModelMessage],
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ def __init__(
self.system_prompt_role = system_prompt_role
self._system = system

@property
def base_url(self) -> str:
    """The endpoint URL configured on the wrapped OpenAI client."""
    configured = self.client.base_url
    return str(configured)

async def request(
self,
messages: list[ModelMessage],
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_init():
assert m.client.api_key == 'foobar'
assert m.model_name == 'claude-3-5-haiku-latest'
assert m.system == 'anthropic'
assert m.base_url == 'https://api.anthropic.com'


@dataclass
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def bedrock_provider():

async def test_bedrock_model(allow_model_requests: None, bedrock_provider: BedrockProvider):
model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider)
assert model.base_url == 'https://bedrock-runtime.us-east-1.amazonaws.com'
agent = Agent(model=model, system_prompt='You are a chatbot.')

result = await agent.run('Hello!')
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_init():
m = CohereModel('command-r7b-12-2024', api_key='foobar')
assert m.model_name == 'command-r7b-12-2024'
assert m.system == 'cohere'
assert m.base_url == 'https://api.cohere.com'


@dataclass
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_init() -> None:
'FallBackModel[function:failure_response:, function:success_response:]'
)
assert fallback_model.system is None
assert fallback_model.base_url is None


def test_first_successful() -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_api_key_arg(env: TestEnv):
env.set('GEMINI_API_KEY', 'via-env-var')
m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(api_key='via-arg'))
assert m.client.headers['x-goog-api-key'] == 'via-arg'
assert m.base_url == 'https://generativelanguage.googleapis.com/v1beta/models/'


def test_api_key_env_var(env: TestEnv):
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_init():
assert m.client.api_key == 'foobar'
assert m.model_name == 'llama-3.3-70b-versatile'
assert m.system == 'groq'
assert m.base_url == 'https://api.groq.com'


@dataclass
Expand Down
12 changes: 12 additions & 0 deletions tests/models/test_instrumented.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def system(self) -> str:
def model_name(self) -> str:
return 'my_model'

@property
def base_url(self) -> str:
    """A fixed URL with an explicit port so span attributes can be asserted."""
    url = 'https://example.com:8000/foo'
    return url

async def request(
self,
messages: list[ModelMessage],
Expand Down Expand Up @@ -146,6 +150,8 @@ async def test_instrumented_model(capfire: CaptureLogfire):
'gen_ai.operation.name': 'chat',
'gen_ai.system': 'my_system',
'gen_ai.request.model': 'my_model',
'server.address': 'example.com',
'server.port': 8000,
'gen_ai.request.temperature': 1,
'logfire.msg': 'chat my_model',
'logfire.span_type': 'span',
Expand Down Expand Up @@ -366,6 +372,8 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire):
'gen_ai.operation.name': 'chat',
'gen_ai.system': 'my_system',
'gen_ai.request.model': 'my_model',
'server.address': 'example.com',
'server.port': 8000,
'gen_ai.request.temperature': 1,
'logfire.msg': 'chat my_model',
'logfire.span_type': 'span',
Expand Down Expand Up @@ -447,6 +455,8 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire):
'gen_ai.operation.name': 'chat',
'gen_ai.system': 'my_system',
'gen_ai.request.model': 'my_model',
'server.address': 'example.com',
'server.port': 8000,
'gen_ai.request.temperature': 1,
'logfire.msg': 'chat my_model',
'logfire.span_type': 'span',
Expand Down Expand Up @@ -547,6 +557,8 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
'gen_ai.operation.name': 'chat',
'gen_ai.system': 'my_system',
'gen_ai.request.model': 'my_model',
'server.address': 'example.com',
'server.port': 8000,
'gen_ai.request.temperature': 1,
'logfire.msg': 'chat my_model',
'logfire.span_type': 'span',
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def func_chunk(
def test_init():
m = MistralModel('mistral-large-latest', api_key='foobar')
assert m.model_name == 'mistral-large-latest'
assert m.base_url == 'https://api.mistral.ai'


#####################
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@

def test_init():
m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key='foobar'))
assert str(m.client.base_url) == 'https://api.openai.com/v1/'
assert m.base_url == 'https://api.openai.com/v1/'
assert m.client.api_key == 'foobar'
assert m.model_name == 'gpt-4o'

Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):

await model.ainit()

assert model.url == snapshot(
assert model.base_url == snapshot(
'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
'publishers/google/models/gemini-1.5-flash:'
)
Expand Down Expand Up @@ -66,7 +66,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):

assert patch.call_count == 1

assert model.url == snapshot(
assert model.base_url == snapshot(
'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
'publishers/google/models/gemini-1.5-flash:'
)
Expand All @@ -75,7 +75,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
assert model.system == snapshot('google-vertex')

await model.ainit()
assert model.url is not None
assert model.base_url is not None
assert model.auth is not None
assert patch.call_count == 1

Expand All @@ -90,7 +90,7 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)

await model.ainit()

assert model.url == snapshot(
assert model.base_url == snapshot(
'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
'publishers/google/models/gemini-1.5-flash:'
)
Expand Down