Skip to content

fix(telemetry): fix agent span start and end when using Agent.stream_async() #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 50 additions & 16 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None

self.trace_span = self.tracer.start_agent_span(
prompt=prompt,
model_id=model_id,
tools=self.tool_names,
system_prompt=self.system_prompt,
custom_trace_attributes=self.trace_attributes,
)
self._start_agent_trace_span(prompt)

try:
# Run the event loop and get the result
result = self._run_loop(prompt, kwargs)

if self.trace_span:
self.tracer.end_agent_span(span=self.trace_span, response=result)
self._end_agent_trace_span(response=result)

return result
except Exception as e:
if self.trace_span:
self.tracer.end_agent_span(span=self.trace_span, error=e)
self._end_agent_trace_span(error=e)

# Re-raise the exception to preserve original behavior
raise
Expand Down Expand Up @@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
yield event["data"]
```
"""
self._start_agent_trace_span(prompt)

_stop_event = uuid4()

queue = asyncio.Queue[Any]()
Expand All @@ -400,8 +392,10 @@ def target_callback() -> None:
nonlocal kwargs

try:
self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
except BaseException as e:
result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
self._end_agent_trace_span(response=result)
except Exception as e:
self._end_agent_trace_span(error=e)
enqueue(e)
finally:
enqueue(_stop_event)
Expand All @@ -414,7 +408,7 @@ def target_callback() -> None:
item = await queue.get()
if item == _stop_event:
break
if isinstance(item, BaseException):
if isinstance(item, Exception):
raise item
yield item
finally:
Expand Down Expand Up @@ -546,3 +540,43 @@ def _record_tool_execution(
messages.append(tool_use_msg)
messages.append(tool_result_msg)
messages.append(assistant_msg)

def _start_agent_trace_span(self, prompt: str) -> None:
"""Starts a trace span for the agent.

Args:
prompt: The natural language prompt from the user.
"""
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None

self.trace_span = self.tracer.start_agent_span(
prompt=prompt,
model_id=model_id,
tools=self.tool_names,
system_prompt=self.system_prompt,
custom_trace_attributes=self.trace_attributes,
)

def _end_agent_trace_span(
self,
response: Optional[AgentResult] = None,
error: Optional[Exception] = None,
) -> None:
"""Ends a trace span for the agent.

Args:
span: The span to end.
response: Response to record as a trace attribute.
error: Error to record as a trace attribute.
"""
if self.trace_span:
trace_attributes: Dict[str, Any] = {
"span": self.trace_span,
}

if response:
trace_attributes["response"] = response
if error:
trace_attributes["error"] = error

self.tracer.end_agent_span(**trace_attributes)
87 changes: 84 additions & 3 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import pytest

import strands
from strands.agent.agent import Agent
from strands import Agent
from strands.agent import AgentResult
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
Expand Down Expand Up @@ -687,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler():

@pytest.mark.asyncio
async def test_stream_async_returns_all_events(mock_event_loop_cycle):
mock_event_loop_cycle.side_effect = ValueError("Test exception")

agent = Agent()

# Define the side effect to simulate callback handler being called multiple times
Expand Down Expand Up @@ -952,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result)


@pytest.mark.asyncio
@unittest.mock.patch("strands.agent.agent.get_tracer")
async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle):
"""Test that stream_async creates and ends a span when the call succeeds."""
# Setup mock tracer and span
mock_tracer = unittest.mock.MagicMock()
mock_span = unittest.mock.MagicMock()
mock_tracer.start_agent_span.return_value = mock_span
mock_get_tracer.return_value = mock_tracer

# Define the side effect to simulate callback handler being called multiple times
def call_callback_handler(*args, **kwargs):
# Extract the callback handler from kwargs
callback_handler = kwargs.get("callback_handler")
# Call the callback handler with different data values
callback_handler(data="First chunk")
callback_handler(data="Second chunk")
callback_handler(data="Final chunk", complete=True)
# Return expected values from event_loop_cycle
return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {}

mock_event_loop_cycle.side_effect = call_callback_handler

# Create agent and make a call
agent = Agent(model=mock_model)
iterator = agent.stream_async("test prompt")
async for _event in iterator:
pass # NoOp

# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
prompt="test prompt",
model_id=unittest.mock.ANY,
tools=agent.tool_names,
system_prompt=agent.system_prompt,
custom_trace_attributes=agent.trace_attributes,
)

expected_response = AgentResult(
stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={}
)

# Verify span was ended with the result
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response)


@unittest.mock.patch("strands.agent.agent.get_tracer")
def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
"""Test that __call__ creates and ends a span when an exception occurs."""
Expand Down Expand Up @@ -985,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)


@pytest.mark.asyncio
@unittest.mock.patch("strands.agent.agent.get_tracer")
async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
"""Test that stream_async creates and ends a span when the call succeeds."""
# Setup mock tracer and span
mock_tracer = unittest.mock.MagicMock()
mock_span = unittest.mock.MagicMock()
mock_tracer.start_agent_span.return_value = mock_span
mock_get_tracer.return_value = mock_tracer

# Define the side effect to simulate callback handler raising an Exception
test_exception = ValueError("Test exception")
mock_model.mock_converse.side_effect = test_exception

# Create agent and make a call
agent = Agent(model=mock_model)

# Call the agent and catch the exception
with pytest.raises(ValueError):
iterator = agent.stream_async("test prompt")
async for _event in iterator:
pass # NoOp

# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
prompt="test prompt",
model_id=unittest.mock.ANY,
tools=agent.tool_names,
system_prompt=agent.system_prompt,
custom_trace_attributes=agent.trace_attributes,
)

# Verify span was ended with the exception
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)


@unittest.mock.patch("strands.agent.agent.get_tracer")
def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model):
"""Test that event_loop_cycle is called with the parent span."""
Expand Down