Skip to content

Fix agent default callback handler #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@
logger = logging.getLogger(__name__)


# Sentinel class and object to distinguish between explicit None and default parameter value
class _DefaultCallbackHandlerSentinel:
"""Sentinel class to distinguish between explicit None and default parameter value."""

pass


_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel()


class Agent:
"""Core Agent interface.

Expand All @@ -70,7 +80,7 @@ def __init__(self, agent: "Agent") -> None:
# agent tools and thus break their execution.
self._agent = agent

def __getattr__(self, name: str) -> Callable:
def __getattr__(self, name: str) -> Callable[..., Any]:
"""Call tool as a function.

This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
Expand Down Expand Up @@ -177,7 +187,9 @@ def __init__(
messages: Optional[Messages] = None,
tools: Optional[List[Union[str, Dict[str, str], Any]]] = None,
system_prompt: Optional[str] = None,
callback_handler: Optional[Callable] = PrintingCallbackHandler(),
callback_handler: Optional[
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
] = _DEFAULT_CALLBACK_HANDLER,
conversation_manager: Optional[ConversationManager] = None,
max_parallel_tools: int = os.cpu_count() or 1,
record_direct_tool_call: bool = True,
Expand All @@ -204,7 +216,8 @@ def __init__(
system_prompt: System prompt to guide model behavior.
If None, the model will behave according to its default settings.
callback_handler: Callback for processing events as they happen during agent execution.
Defaults to strands.handlers.PrintingCallbackHandler if None.
If not provided (using the default), a new PrintingCallbackHandler instance is created.
If explicitly set to None, null_callback_handler is used.
conversation_manager: Manager for conversation history and context window.
Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None.
max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls.
Expand All @@ -222,7 +235,17 @@ def __init__(
self.messages = messages if messages is not None else []

self.system_prompt = system_prompt
self.callback_handler = callback_handler or null_callback_handler

# If not provided, create a new PrintingCallbackHandler instance
# If explicitly set to None, use null_callback_handler
# Otherwise use the passed callback_handler
self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler]
if isinstance(callback_handler, _DefaultCallbackHandlerSentinel):
self.callback_handler = PrintingCallbackHandler()
elif callback_handler is None:
self.callback_handler = null_callback_handler
else:
self.callback_handler = callback_handler

self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager()

Expand Down Expand Up @@ -415,7 +438,7 @@ def target_callback() -> None:
thread.join()

def _run_loop(
self, prompt: str, kwargs: Any, supplementary_callback_handler: Optional[Callable] = None
self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None
) -> AgentResult:
"""Execute the agent's event loop with the given prompt and parameters."""
try:
Expand All @@ -441,7 +464,7 @@ def _run_loop(
finally:
self.conversation_manager.apply_management(self)

def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult:
def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult:
"""Execute the event loop cycle with retry logic for context window limits.

This internal method handles the execution of the event loop cycle and implements
Expand Down
31 changes: 31 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,37 @@ def test_agent_with_callback_handler_none_uses_null_handler():
assert agent.callback_handler == null_callback_handler


def test_agent_callback_handler_not_provided_creates_new_instances():
"""Test that when callback_handler is not provided, new PrintingCallbackHandler instances are created."""
# Create two agents without providing callback_handler
agent1 = Agent()
agent2 = Agent()

# Both should have PrintingCallbackHandler instances
assert isinstance(agent1.callback_handler, PrintingCallbackHandler)
assert isinstance(agent2.callback_handler, PrintingCallbackHandler)

# But they should be different object instances
assert agent1.callback_handler is not agent2.callback_handler


def test_agent_callback_handler_explicit_none_uses_null_handler():
"""Test that when callback_handler is explicitly set to None, null_callback_handler is used."""
agent = Agent(callback_handler=None)

# Should use null_callback_handler
assert agent.callback_handler is null_callback_handler


def test_agent_callback_handler_custom_handler_used():
"""Test that when a custom callback_handler is provided, it is used."""
custom_handler = unittest.mock.Mock()
agent = Agent(callback_handler=custom_handler)

# Should use the provided custom handler
assert agent.callback_handler is custom_handler


@pytest.mark.asyncio
async def test_stream_async_returns_all_events(mock_event_loop_cycle):
agent = Agent()
Expand Down