Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions py/src/braintrust/wrappers/pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ async def agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any)
result = await wrapped(*args, **kwargs)
end_time = time.time()

_create_tool_spans_from_messages(result)

output = _serialize_result_output(result)
metrics = _extract_usage_metrics(result, start_time, end_time)

Expand All @@ -104,6 +106,8 @@ def agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
result = wrapped(*args, **kwargs)
end_time = time.time()

_create_tool_spans_from_messages(result)

output = _serialize_result_output(result)
metrics = _extract_usage_metrics(result, start_time, end_time)

Expand Down Expand Up @@ -184,6 +188,9 @@ async def agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: Any

end_time = time.time()

if final_result:
_create_tool_spans_from_messages(final_result)

output = None
metrics = {
"start": start_time,
Expand Down Expand Up @@ -487,6 +494,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.span_cm and self.start_time and self.stream_result:
end_time = time.time()

_create_tool_spans_from_messages(self.stream_result)

output = _serialize_stream_output(self.stream_result)
metrics = _extract_stream_usage_metrics(
self.stream_result, self.start_time, end_time, self._first_token_time
Expand Down Expand Up @@ -668,6 +677,9 @@ def _finalize(self):
if self._span and not self._logged and self._stream_result:
try:
end_time = time.time()

_create_tool_spans_from_messages(self._stream_result)

output = _serialize_stream_output(self._stream_result)
metrics = _extract_stream_usage_metrics(
self._stream_result, self._start_time, end_time, self._first_token_time
Expand Down Expand Up @@ -771,6 +783,89 @@ def __next__(self):
return item


def _create_tool_spans_from_messages(result: Any) -> None:
"""
Create TOOL-type spans from tool call/return message parts in a completed agent result.

Uses message timestamps from PydanticAI to position spans correctly in the trace:
- start_time = ModelResponse.timestamp (when the model requested the tool call)
- end_time = ModelRequest.timestamp (when the tool result was sent back)
"""
try:
_create_tool_spans_from_messages_impl(result)
except Exception:
pass


def _create_tool_spans_from_messages_impl(result: Any) -> None:
from pydantic_ai.messages import ToolCallPart, ToolReturnPart

messages = result.new_messages()

returns_by_id: dict[str, tuple[Any, float | None]] = {}
for msg in messages:
if not hasattr(msg, "parts"):
continue
msg_ts = _msg_timestamp(msg)
for part in msg.parts:
if isinstance(part, ToolReturnPart) and hasattr(part, "tool_call_id"):
returns_by_id[part.tool_call_id] = (part, msg_ts)

for msg in messages:
if not hasattr(msg, "parts"):
continue
call_ts = _msg_timestamp(msg)
for part in msg.parts:
if not isinstance(part, ToolCallPart):
continue

tool_name = getattr(part, "tool_name", None) or "unknown_tool"
tool_call_id = getattr(part, "tool_call_id", None)

try:
input_data = part.args_as_dict()
except Exception:
input_data = bt_safe_deep_copy(getattr(part, "args", None))

output_data = None
return_ts: float | None = None
if tool_call_id and tool_call_id in returns_by_id:
return_part, return_ts = returns_by_id[tool_call_id]
output_data = bt_safe_deep_copy(getattr(return_part, "content", None))

metadata = {}
if tool_call_id:
metadata["tool_call_id"] = tool_call_id

with start_span(
name=tool_name,
type=SpanTypeAttribute.TOOL,
input=input_data,
start_time=call_ts,
metadata=metadata if metadata else None,
) as tool_span:
metrics = {}
if call_ts is not None:
metrics["start"] = call_ts
if return_ts is not None:
metrics["end"] = return_ts
if call_ts is not None and return_ts is not None:
metrics["duration"] = return_ts - call_ts
tool_span.log(output=output_data, metrics=metrics if metrics else None)
tool_span.end(end_time=return_ts)


def _msg_timestamp(msg: Any) -> float | None:
"""Extract epoch-seconds timestamp from a PydanticAI message, or None."""
ts = getattr(msg, "timestamp", None)
if ts is None:
return None
try:
return ts.timestamp() # datetime → float
except Exception:
return None


def _serialize_user_prompt(user_prompt: Any) -> Any:
"""Serialize user prompt, handling BinaryContent and other types."""
if user_prompt is None:
Expand Down
40 changes: 29 additions & 11 deletions py/src/braintrust/wrappers/test_pydantic_ai_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,15 @@ def get_weather(city: str) -> str:
assert "weather" in str(agent_span["input"]).lower() or "paris" in str(agent_span["input"]).lower()
_assert_metrics_are_valid(agent_span["metrics"], start, end)

tool_spans = [s for s in spans if s["span_attributes"].get("type") == SpanTypeAttribute.TOOL]
assert len(tool_spans) >= 1, f"Expected at least 1 TOOL span, got {len(tool_spans)}"

weather_tool_span = next((s for s in tool_spans if s["span_attributes"]["name"] == "get_weather"), None)
assert weather_tool_span is not None, "get_weather TOOL span not found"
assert "Paris" in str(weather_tool_span["input"]) or "paris" in str(weather_tool_span["input"]).lower()
assert "sunny" in str(weather_tool_span["output"]).lower()
assert weather_tool_span["span_parents"] == [agent_span["span_id"]], "tool span should be nested under agent_run"


@pytest.mark.vcr
@pytest.mark.asyncio
Expand Down Expand Up @@ -1796,17 +1805,19 @@ def calculate(operation: str, a: float, b: float) -> str:
# Verify toolsets are NOT in metadata (following the principle: agent.run() accepts it)
assert "toolsets" not in agent_span["metadata"], "toolsets should NOT be in metadata"

tool_spans = [s for s in spans if s["span_attributes"].get("type") == SpanTypeAttribute.TOOL]
assert len(tool_spans) >= 1, f"Expected at least 1 TOOL span, got {len(tool_spans)}"

@pytest.mark.vcr
def test_tool_execution_creates_spans(memory_logger):
"""Test that executing tools with agents works and creates traced spans.
calc_tool_span = next((s for s in tool_spans if s["span_attributes"]["name"] == "calculate"), None)
assert calc_tool_span is not None, "calculate TOOL span not found"
assert calc_tool_span["input"] is not None, "tool span should have input"
assert calc_tool_span["output"] is not None, "tool span should have output"
assert calc_tool_span["span_parents"] == [agent_span["span_id"]], "tool span should be nested under agent_run"

Note: Tool-level span creation is not yet implemented in the wrapper.
This test verifies that agents with tools work correctly and produce agent/chat spans.

Future enhancement: Add automatic span creation for tool executions as children of
the chat span that requested them.
"""
@pytest.mark.vcr
def test_tool_execution_creates_spans(memory_logger):
"""Test that executing tools with agents works and creates traced spans."""
assert not memory_logger.pop()

start = time.time()
Expand Down Expand Up @@ -1853,9 +1864,16 @@ def calculate(operation: str, a: float, b: float) -> float:
tool_names = [t["name"] for t in tools if isinstance(t, dict)]
assert "calculate" in tool_names, f"calculate tool should be in toolset, got: {tool_names}"

# TODO: Future enhancement - verify tool execution spans are created
# tool_spans = [s for s in spans if "calculate" in s["span_attributes"].get("name", "")]
# assert len(tool_spans) > 0, "Tool execution should create spans"
tool_spans = [s for s in spans if s["span_attributes"].get("type") == SpanTypeAttribute.TOOL]
assert len(tool_spans) >= 1, f"Expected at least 1 TOOL span, got {len(tool_spans)}"

calc_tool_span = next((s for s in tool_spans if s["span_attributes"]["name"] == "calculate"), None)
assert calc_tool_span is not None, "calculate TOOL span not found"
assert calc_tool_span["input"] is not None, "tool span should have input"
assert calc_tool_span["output"] is not None, "tool span should have output"
# Verify tool span is nested within the agent span tree
all_span_ids = {s["span_id"] for s in spans}
assert calc_tool_span["span_parents"][0] in all_span_ids, "tool span should be nested under a span in the agent tree"


def test_agent_tool_metadata_extraction(memory_logger):
Expand Down
Loading