Skip to content

Commit 75dbbad

Browse files
authored
stop passing around callback handler (#323)
1 parent bd36b95 commit 75dbbad

File tree

8 files changed

+13
-84
lines changed

8 files changed

+13
-84
lines changed

src/strands/agent/agent.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def caller(
134134
system_prompt=self._agent.system_prompt,
135135
messages=self._agent.messages,
136136
tool_config=self._agent.tool_config,
137-
callback_handler=self._agent.callback_handler,
138137
kwargs=kwargs,
139138
)
140139

@@ -375,7 +374,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
375374
self._start_agent_trace_span(prompt)
376375

377376
try:
378-
events = self._run_loop(callback_handler, prompt, kwargs)
377+
events = self._run_loop(prompt, kwargs)
379378
for event in events:
380379
if "callback" in event:
381380
callback_handler(**event["callback"])
@@ -457,7 +456,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
457456
self._start_agent_trace_span(prompt)
458457

459458
try:
460-
events = self._run_loop(callback_handler, prompt, kwargs)
459+
events = self._run_loop(prompt, kwargs)
461460
for event in events:
462461
if "callback" in event:
463462
callback_handler(**event["callback"])
@@ -472,9 +471,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
472471
self._end_agent_trace_span(error=e)
473472
raise
474473

475-
def _run_loop(
476-
self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any]
477-
) -> Generator[dict[str, Any], None, None]:
474+
def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
478475
"""Execute the agent's event loop with the given prompt and parameters."""
479476
try:
480477
# Extract key parameters
@@ -486,14 +483,12 @@ def _run_loop(
486483
self.messages.append(new_message)
487484

488485
# Execute the event loop cycle with retry logic for context limits
489-
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
486+
yield from self._execute_event_loop_cycle(kwargs)
490487

491488
finally:
492489
self.conversation_manager.apply_management(self)
493490

494-
def _execute_event_loop_cycle(
495-
self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]
496-
) -> Generator[dict[str, Any], None, None]:
491+
def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
497492
"""Execute the event loop cycle with retry logic for context window limits.
498493
499494
This internal method handles the execution of the event loop cycle and implements
@@ -513,7 +508,6 @@ def _execute_event_loop_cycle(
513508
system_prompt=self.system_prompt,
514509
messages=self.messages, # will be modified by event_loop_cycle
515510
tool_config=self.tool_config,
516-
callback_handler=callback_handler,
517511
tool_handler=self.tool_handler,
518512
tool_execution_handler=self.thread_pool_wrapper,
519513
event_loop_metrics=self.event_loop_metrics,
@@ -524,7 +518,7 @@ def _execute_event_loop_cycle(
524518
except ContextWindowOverflowException as e:
525519
# Try reducing the context size and retrying
526520
self.conversation_manager.reduce_context(self, e=e)
527-
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
521+
yield from self._execute_event_loop_cycle(kwargs)
528522

529523
def _record_tool_execution(
530524
self,

src/strands/event_loop/event_loop.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
import uuid
1414
from functools import partial
15-
from typing import Any, Callable, Generator, Optional, cast
15+
from typing import Any, Generator, Optional, cast
1616

1717
from opentelemetry import trace
1818

@@ -40,7 +40,6 @@ def event_loop_cycle(
4040
system_prompt: Optional[str],
4141
messages: Messages,
4242
tool_config: Optional[ToolConfig],
43-
callback_handler: Callable[..., Any],
4443
tool_handler: Optional[ToolHandler],
4544
tool_execution_handler: Optional[ParallelToolExecutorInterface],
4645
event_loop_metrics: EventLoopMetrics,
@@ -65,7 +64,6 @@ def event_loop_cycle(
6564
system_prompt: System prompt instructions for the model.
6665
messages: Conversation history messages.
6766
tool_config: Configuration for available tools.
68-
callback_handler: Callback for processing events as they happen.
6967
tool_handler: Handler for executing tools.
7068
tool_execution_handler: Optional handler for parallel tool execution.
7169
event_loop_metrics: Metrics tracking object for the event loop.
@@ -212,7 +210,6 @@ def event_loop_cycle(
212210
messages,
213211
tool_config,
214212
tool_handler,
215-
callback_handler,
216213
tool_execution_handler,
217214
event_loop_metrics,
218215
event_loop_parent_span,
@@ -258,7 +255,6 @@ def recurse_event_loop(
258255
system_prompt: Optional[str],
259256
messages: Messages,
260257
tool_config: Optional[ToolConfig],
261-
callback_handler: Callable[..., Any],
262258
tool_handler: Optional[ToolHandler],
263259
tool_execution_handler: Optional[ParallelToolExecutorInterface],
264260
event_loop_metrics: EventLoopMetrics,
@@ -274,7 +270,6 @@ def recurse_event_loop(
274270
system_prompt: System prompt instructions for the model
275271
messages: Conversation history messages
276272
tool_config: Configuration for available tools
277-
callback_handler: Callback for processing events as they happen
278273
tool_handler: Handler for tool execution
279274
tool_execution_handler: Optional handler for parallel tool execution.
280275
event_loop_metrics: Metrics tracking object for the event loop.
@@ -302,7 +297,6 @@ def recurse_event_loop(
302297
system_prompt=system_prompt,
303298
messages=messages,
304299
tool_config=tool_config,
305-
callback_handler=callback_handler,
306300
tool_handler=tool_handler,
307301
tool_execution_handler=tool_execution_handler,
308302
event_loop_metrics=event_loop_metrics,
@@ -321,7 +315,6 @@ def _handle_tool_execution(
321315
messages: Messages,
322316
tool_config: ToolConfig,
323317
tool_handler: ToolHandler,
324-
callback_handler: Callable[..., Any],
325318
tool_execution_handler: Optional[ParallelToolExecutorInterface],
326319
event_loop_metrics: EventLoopMetrics,
327320
event_loop_parent_span: Optional[trace.Span],
@@ -345,7 +338,6 @@ def _handle_tool_execution(
345338
messages (Messages): The conversation history messages.
346339
tool_config (ToolConfig): Configuration for available tools.
347340
tool_handler (ToolHandler): Handler for tool execution.
348-
callback_handler (Callable[..., Any]): Callback for processing events as they happen.
349341
tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
350342
event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
351343
event_loop_parent_span (Any): Span for the parent of this event loop.
@@ -374,7 +366,6 @@ def _handle_tool_execution(
374366
system_prompt=system_prompt,
375367
messages=messages,
376368
tool_config=tool_config,
377-
callback_handler=callback_handler,
378369
kwargs=kwargs,
379370
)
380371

@@ -415,7 +406,6 @@ def _handle_tool_execution(
415406
system_prompt=system_prompt,
416407
messages=messages,
417408
tool_config=tool_config,
418-
callback_handler=callback_handler,
419409
tool_handler=tool_handler,
420410
tool_execution_handler=tool_execution_handler,
421411
event_loop_metrics=event_loop_metrics,

src/strands/handlers/tool_handler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def process(
3434
system_prompt: Optional[str],
3535
messages: Messages,
3636
tool_config: ToolConfig,
37-
callback_handler: Any,
3837
kwargs: dict[str, Any],
3938
) -> Any:
4039
"""Process a tool invocation.
@@ -47,7 +46,6 @@ def process(
4746
system_prompt: The system prompt for the agent.
4847
messages: The conversation history.
4948
tool_config: Configuration for the tool.
50-
callback_handler: Callback for processing events as they happen.
5149
kwargs: Additional keyword arguments passed to the tool.
5250
5351
Returns:
@@ -81,7 +79,6 @@ def process(
8179
"system_prompt": system_prompt,
8280
"messages": messages,
8381
"tool_config": tool_config,
84-
"callback_handler": callback_handler,
8582
}
8683
)
8784

src/strands/models/mistral.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import base64
77
import json
88
import logging
9-
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union
9+
from typing import Any, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union
1010

1111
from mistralai import Mistral
1212
from pydantic import BaseModel
@@ -471,14 +471,15 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
471471

472472
@override
473473
def structured_output(
474-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
474+
self,
475+
output_model: Type[T],
476+
prompt: Messages,
475477
) -> Generator[dict[str, Union[T, Any]], None, None]:
476478
"""Get structured output from the model.
477479
478480
Args:
479481
output_model: The output model to use for the agent.
480482
prompt: The prompt messages to use for the agent.
481-
callback_handler: Optional callback handler for processing events.
482483
483484
Returns:
484485
An instance of the output model with the generated data.

src/strands/types/tools.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ def process(
249249
system_prompt: Optional[str],
250250
messages: "Messages",
251251
tool_config: ToolConfig,
252-
callback_handler: Any,
253252
kwargs: dict[str, Any],
254253
) -> ToolResult:
255254
"""Process a tool use request and execute the tool.
@@ -260,7 +259,6 @@ def process(
260259
model: The model being used for the conversation.
261260
system_prompt: The system prompt for the conversation.
262261
tool_config: The tool configuration for the current session.
263-
callback_handler: Callback for processing events as they happen.
264262
kwargs: Additional context-specific arguments.
265263
266264
Returns:

tests/strands/agent/test_agent.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,6 @@ def function(system_prompt: str) -> str:
796796
system_prompt="You are a helpful assistant.",
797797
messages=unittest.mock.ANY,
798798
tool_config=unittest.mock.ANY,
799-
callback_handler=unittest.mock.ANY,
800799
kwargs={"system_prompt": "tool prompt"},
801800
)
802801

@@ -1076,18 +1075,10 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac
10761075
mock_tracer.start_agent_span.return_value = mock_span
10771076
mock_get_tracer.return_value = mock_tracer
10781077

1079-
# Define the side effect to simulate callback handler being called multiple times
1080-
def call_callback_handler(*args, **kwargs):
1081-
# Extract the callback handler from kwargs
1082-
callback_handler = kwargs.get("callback_handler")
1083-
# Call the callback handler with different data values
1084-
callback_handler(data="First chunk")
1085-
callback_handler(data="Second chunk")
1086-
callback_handler(data="Final chunk", complete=True)
1087-
# Return expected values from event_loop_cycle
1078+
def test_event_loop(*args, **kwargs):
10881079
yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})}
10891080

1090-
mock_event_loop_cycle.side_effect = call_callback_handler
1081+
mock_event_loop_cycle.side_effect = test_event_loop
10911082

10921083
# Create agent and make a call
10931084
agent = Agent(model=mock_model)

0 commit comments

Comments
 (0)