Skip to content

Commit 36cb9c3

Browse files
pgrayyjsamuel1
authored andcommitted
stop passing around callback handler (strands-agents#323)
1 parent 1e8a23e commit 36cb9c3

File tree

8 files changed

+13
-84
lines changed

8 files changed

+13
-84
lines changed

src/strands/agent/agent.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def caller(
135135
system_prompt=self._agent.system_prompt,
136136
messages=self._agent.messages,
137137
tool_config=self._agent.tool_config,
138-
callback_handler=self._agent.callback_handler,
139138
kwargs=kwargs,
140139
)
141140

@@ -415,7 +414,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
415414
self._start_agent_trace_span(prompt)
416415

417416
try:
418-
events = self._run_loop(callback_handler, prompt, kwargs)
417+
events = self._run_loop(prompt, kwargs)
419418
for event in events:
420419
if "callback" in event:
421420
callback_handler(**event["callback"])
@@ -497,7 +496,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
497496
self._start_agent_trace_span(prompt)
498497

499498
try:
500-
events = self._run_loop(callback_handler, prompt, kwargs)
499+
events = self._run_loop(prompt, kwargs)
501500
for event in events:
502501
if "callback" in event:
503502
callback_handler(**event["callback"])
@@ -512,9 +511,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
512511
self._end_agent_trace_span(error=e)
513512
raise
514513

515-
def _run_loop(
516-
self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any]
517-
) -> Generator[dict[str, Any], None, None]:
514+
def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
518515
"""Execute the agent's event loop with the given prompt and parameters."""
519516
try:
520517
# Extract key parameters
@@ -526,14 +523,12 @@ def _run_loop(
526523
self.messages.append(new_message)
527524

528525
# Execute the event loop cycle with retry logic for context limits
529-
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
526+
yield from self._execute_event_loop_cycle(kwargs)
530527

531528
finally:
532529
self.conversation_manager.apply_management(self)
533530

534-
def _execute_event_loop_cycle(
535-
self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]
536-
) -> Generator[dict[str, Any], None, None]:
531+
def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
537532
"""Execute the event loop cycle with retry logic for context window limits.
538533
539534
This internal method handles the execution of the event loop cycle and implements
@@ -576,7 +571,6 @@ def _execute_event_loop_cycle(
576571
system_prompt=self.system_prompt,
577572
messages=self.messages, # will be modified by event_loop_cycle
578573
tool_config=self.tool_config,
579-
callback_handler=callback_handler,
580574
tool_handler=self.tool_handler,
581575
tool_execution_handler=self.thread_pool_wrapper,
582576
event_loop_metrics=self.event_loop_metrics,
@@ -587,7 +581,7 @@ def _execute_event_loop_cycle(
587581
except ContextWindowOverflowException as e:
588582
# Try reducing the context size and retrying
589583
self.conversation_manager.reduce_context(self, e=e)
590-
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
584+
yield from self._execute_event_loop_cycle(kwargs)
591585

592586
def _record_tool_execution(
593587
self,

src/strands/event_loop/event_loop.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
import uuid
1414
from functools import partial
15-
from typing import Any, Callable, Generator, Optional, cast
15+
from typing import Any, Generator, Optional, cast
1616

1717
from opentelemetry import trace
1818

@@ -40,7 +40,6 @@ def event_loop_cycle(
4040
system_prompt: Optional[str],
4141
messages: Messages,
4242
tool_config: Optional[ToolConfig],
43-
callback_handler: Callable[..., Any],
4443
tool_handler: Optional[ToolHandler],
4544
tool_execution_handler: Optional[ParallelToolExecutorInterface],
4645
event_loop_metrics: EventLoopMetrics,
@@ -65,7 +64,6 @@ def event_loop_cycle(
6564
system_prompt: System prompt instructions for the model.
6665
messages: Conversation history messages.
6766
tool_config: Configuration for available tools.
68-
callback_handler: Callback for processing events as they happen.
6967
tool_handler: Handler for executing tools.
7068
tool_execution_handler: Optional handler for parallel tool execution.
7169
event_loop_metrics: Metrics tracking object for the event loop.
@@ -212,7 +210,6 @@ def event_loop_cycle(
212210
messages,
213211
tool_config,
214212
tool_handler,
215-
callback_handler,
216213
tool_execution_handler,
217214
event_loop_metrics,
218215
event_loop_parent_span,
@@ -258,7 +255,6 @@ def recurse_event_loop(
258255
system_prompt: Optional[str],
259256
messages: Messages,
260257
tool_config: Optional[ToolConfig],
261-
callback_handler: Callable[..., Any],
262258
tool_handler: Optional[ToolHandler],
263259
tool_execution_handler: Optional[ParallelToolExecutorInterface],
264260
event_loop_metrics: EventLoopMetrics,
@@ -274,7 +270,6 @@ def recurse_event_loop(
274270
system_prompt: System prompt instructions for the model
275271
messages: Conversation history messages
276272
tool_config: Configuration for available tools
277-
callback_handler: Callback for processing events as they happen
278273
tool_handler: Handler for tool execution
279274
tool_execution_handler: Optional handler for parallel tool execution.
280275
event_loop_metrics: Metrics tracking object for the event loop.
@@ -302,7 +297,6 @@ def recurse_event_loop(
302297
system_prompt=system_prompt,
303298
messages=messages,
304299
tool_config=tool_config,
305-
callback_handler=callback_handler,
306300
tool_handler=tool_handler,
307301
tool_execution_handler=tool_execution_handler,
308302
event_loop_metrics=event_loop_metrics,
@@ -321,7 +315,6 @@ def _handle_tool_execution(
321315
messages: Messages,
322316
tool_config: ToolConfig,
323317
tool_handler: ToolHandler,
324-
callback_handler: Callable[..., Any],
325318
tool_execution_handler: Optional[ParallelToolExecutorInterface],
326319
event_loop_metrics: EventLoopMetrics,
327320
event_loop_parent_span: Optional[trace.Span],
@@ -345,7 +338,6 @@ def _handle_tool_execution(
345338
messages (Messages): The conversation history messages.
346339
tool_config (ToolConfig): Configuration for available tools.
347340
tool_handler (ToolHandler): Handler for tool execution.
348-
callback_handler (Callable[..., Any]): Callback for processing events as they happen.
349341
tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
350342
event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
351343
event_loop_parent_span (Any): Span for the parent of this event loop.
@@ -374,7 +366,6 @@ def _handle_tool_execution(
374366
system_prompt=system_prompt,
375367
messages=messages,
376368
tool_config=tool_config,
377-
callback_handler=callback_handler,
378369
kwargs=kwargs,
379370
)
380371

@@ -415,7 +406,6 @@ def _handle_tool_execution(
415406
system_prompt=system_prompt,
416407
messages=messages,
417408
tool_config=tool_config,
418-
callback_handler=callback_handler,
419409
tool_handler=tool_handler,
420410
tool_execution_handler=tool_execution_handler,
421411
event_loop_metrics=event_loop_metrics,

src/strands/handlers/tool_handler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def process(
3434
system_prompt: Optional[str],
3535
messages: Messages,
3636
tool_config: ToolConfig,
37-
callback_handler: Any,
3837
kwargs: dict[str, Any],
3938
) -> Any:
4039
"""Process a tool invocation.
@@ -47,7 +46,6 @@ def process(
4746
system_prompt: The system prompt for the agent.
4847
messages: The conversation history.
4948
tool_config: Configuration for the tool.
50-
callback_handler: Callback for processing events as they happen.
5149
kwargs: Additional keyword arguments passed to the tool.
5250
5351
Returns:
@@ -81,7 +79,6 @@ def process(
8179
"system_prompt": system_prompt,
8280
"messages": messages,
8381
"tool_config": tool_config,
84-
"callback_handler": callback_handler,
8582
}
8683
)
8784

src/strands/models/mistral.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import base64
77
import json
88
import logging
9-
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union
9+
from typing import Any, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union
1010

1111
from mistralai import Mistral
1212
from pydantic import BaseModel
@@ -471,14 +471,15 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
471471

472472
@override
473473
def structured_output(
474-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
474+
self,
475+
output_model: Type[T],
476+
prompt: Messages,
475477
) -> Generator[dict[str, Union[T, Any]], None, None]:
476478
"""Get structured output from the model.
477479
478480
Args:
479481
output_model: The output model to use for the agent.
480482
prompt: The prompt messages to use for the agent.
481-
callback_handler: Optional callback handler for processing events.
482483
483484
Returns:
484485
An instance of the output model with the generated data.

src/strands/types/tools.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ def process(
249249
system_prompt: Optional[str],
250250
messages: "Messages",
251251
tool_config: ToolConfig,
252-
callback_handler: Any,
253252
kwargs: dict[str, Any],
254253
) -> ToolResult:
255254
"""Process a tool use request and execute the tool.
@@ -260,7 +259,6 @@ def process(
260259
model: The model being used for the conversation.
261260
system_prompt: The system prompt for the conversation.
262261
tool_config: The tool configuration for the current session.
263-
callback_handler: Callback for processing events as they happen.
264262
kwargs: Additional context-specific arguments.
265263
266264
Returns:

tests/strands/agent/test_agent.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,6 @@ def function(system_prompt: str) -> str:
796796
system_prompt="You are a helpful assistant.",
797797
messages=unittest.mock.ANY,
798798
tool_config=unittest.mock.ANY,
799-
callback_handler=unittest.mock.ANY,
800799
kwargs={"system_prompt": "tool prompt"},
801800
)
802801

@@ -1076,18 +1075,10 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac
10761075
mock_tracer.start_agent_span.return_value = mock_span
10771076
mock_get_tracer.return_value = mock_tracer
10781077

1079-
# Define the side effect to simulate callback handler being called multiple times
1080-
def call_callback_handler(*args, **kwargs):
1081-
# Extract the callback handler from kwargs
1082-
callback_handler = kwargs.get("callback_handler")
1083-
# Call the callback handler with different data values
1084-
callback_handler(data="First chunk")
1085-
callback_handler(data="Second chunk")
1086-
callback_handler(data="Final chunk", complete=True)
1087-
# Return expected values from event_loop_cycle
1078+
def test_event_loop(*args, **kwargs):
10881079
yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})}
10891080

1090-
mock_event_loop_cycle.side_effect = call_callback_handler
1081+
mock_event_loop_cycle.side_effect = test_event_loop
10911082

10921083
# Create agent and make a call
10931084
agent = Agent(model=mock_model)

0 commit comments

Comments
 (0)