Skip to content

Commit 64fc331

Browse files
committed
feat: Add hooks for before/after tool calls + allow hooks to update values
Add the ability to intercept/modify tool calls by implementing support for BeforeToolInvocationEvent & AfterToolInvocationEvent hooks
1 parent 46f66be commit 64fc331

File tree

8 files changed

+470
-36
lines changed

8 files changed

+470
-36
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import uuid
1414
from typing import TYPE_CHECKING, Any, AsyncGenerator
1515

16+
from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
1617
from ..telemetry.metrics import Trace
1718
from ..telemetry.tracer import get_tracer
1819
from ..tools.executor import run_tools, validate_and_prepare_tools
@@ -271,47 +272,78 @@ def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGener
271272
The final tool result or an error response if the tool fails or is not found.
272273
"""
273274
logger.debug("tool=<%s> | invoking", tool)
274-
tool_use_id = tool["toolUseId"]
275275
tool_name = tool["name"]
276276

277277
# Get the tool info
278278
tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
279279
tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name)
280280

281+
# Add standard arguments to kwargs for Python tools
282+
kwargs.update(
283+
{
284+
"model": agent.model,
285+
"system_prompt": agent.system_prompt,
286+
"messages": agent.messages,
287+
"tool_config": agent.tool_config,
288+
}
289+
)
290+
291+
before_event = BeforeToolInvocationEvent(
292+
agent=agent,
293+
selected_tool=tool_func,
294+
tool_use=tool,
295+
kwargs=kwargs,
296+
)
297+
agent._hooks.invoke_callbacks(before_event)
298+
281299
try:
300+
selected_tool = before_event.selected_tool
301+
tool_use = before_event.tool_use
302+
282303
# Check if tool exists
283-
if not tool_func:
304+
if not selected_tool:
284305
logger.error(
285306
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
286307
tool_name,
287308
list(agent.tool_registry.registry.keys()),
288309
)
289310
return {
290-
"toolUseId": tool_use_id,
311+
"toolUseId": str(tool_use.get("toolUseId")),
291312
"status": "error",
292313
"content": [{"text": f"Unknown tool: {tool_name}"}],
293314
}
294-
# Add standard arguments to kwargs for Python tools
295-
kwargs.update(
296-
{
297-
"model": agent.model,
298-
"system_prompt": agent.system_prompt,
299-
"messages": agent.messages,
300-
"tool_config": agent.tool_config,
301-
}
302-
)
303315

304-
result = tool_func.invoke(tool, **kwargs)
305-
yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from
306-
return result
316+
result = selected_tool.invoke(tool_use, **kwargs)
317+
after_event = AfterToolInvocationEvent(
318+
agent=agent,
319+
selected_tool=selected_tool,
320+
tool_use=tool_use,
321+
kwargs=kwargs,
322+
result=result,
323+
)
324+
agent._hooks.invoke_callbacks(after_event)
325+
yield {
326+
"result": after_event.result
327+
} # Placeholder until tool_func becomes a generator from which we can yield from
328+
return after_event.result
307329

308330
except Exception as e:
309331
logger.exception("tool_name=<%s> | failed to process tool", tool_name)
310-
return {
311-
"toolUseId": tool_use_id,
332+
error_result: ToolResult = {
333+
"toolUseId": str(tool_use.get("toolUseId")),
312334
"status": "error",
313335
"content": [{"text": f"Error: {str(e)}"}],
314336
}
337+
after_event = AfterToolInvocationEvent(
338+
agent=agent,
339+
selected_tool=selected_tool,
340+
tool_use=tool_use,
341+
kwargs=kwargs,
342+
result=error_result,
343+
exception=e,
344+
)
345+
agent._hooks.invoke_callbacks(after_event)
346+
return after_event.result
315347

316348

317349
async def _handle_tool_execution(

src/strands/experimental/hooks/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,21 @@ def log_end(self, event: EndRequestEvent) -> None:
2929
type-safe system that supports multiple subscribers per event type.
3030
"""
3131

32-
from .events import AgentInitializedEvent, EndRequestEvent, StartRequestEvent
32+
from .events import (
33+
AfterToolInvocationEvent,
34+
AgentInitializedEvent,
35+
BeforeToolInvocationEvent,
36+
EndRequestEvent,
37+
StartRequestEvent,
38+
)
3339
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry
3440

3541
__all__ = [
3642
"AgentInitializedEvent",
3743
"StartRequestEvent",
3844
"EndRequestEvent",
45+
"BeforeToolInvocationEvent",
46+
"AfterToolInvocationEvent",
3947
"HookEvent",
4048
"HookProvider",
4149
"HookCallback",

src/strands/experimental/hooks/events.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"""
55

66
from dataclasses import dataclass
7+
from typing import Any, Optional
78

9+
from ...types.tools import AgentTool, ToolResult, ToolUse
810
from .registry import HookEvent
911

1012

@@ -56,9 +58,63 @@ class EndRequestEvent(HookEvent):
5658

5759
@property
5860
def should_reverse_callbacks(self) -> bool:
59-
"""Return True to invoke callbacks in reverse order for proper cleanup.
61+
"""True to invoke callbacks in reverse order."""
62+
return True
63+
64+
65+
@dataclass
66+
class BeforeToolInvocationEvent(HookEvent):
67+
"""Event triggered before a tool is invoked.
68+
69+
This event is fired just before the agent executes a tool, allowing hook
70+
providers to inspect, modify, or replace the tool that will be executed.
71+
The selected_tool can be modified by hook callbacks to change which tool
72+
gets executed.
73+
74+
Attributes:
75+
selected_tool: The tool that will be invoked. Can be modified by hooks
76+
to change which tool gets executed. This may be None if tool lookup failed.
77+
tool_use: The tool parameters that will be passed to selected_tool.
78+
kwargs: Keyword arguments that will be passed to the tool.
79+
"""
80+
81+
selected_tool: Optional[AgentTool]
82+
tool_use: ToolUse
83+
kwargs: dict[str, Any]
84+
85+
def _can_write(self, name: str) -> bool:
86+
return name in ["selected_tool", "tool_use"]
87+
88+
89+
@dataclass
90+
class AfterToolInvocationEvent(HookEvent):
91+
"""Event triggered after a tool invocation completes.
6092
61-
Returns:
62-
True, indicating callbacks should be invoked in reverse order.
63-
"""
93+
This event is fired after the agent has finished executing a tool,
94+
regardless of whether the execution was successful or resulted in an error.
95+
Hook providers can use this event for cleanup, logging, or post-processing.
96+
97+
Note: This event uses reverse callback ordering, meaning callbacks registered
98+
later will be invoked first during cleanup.
99+
100+
Attributes:
101+
selected_tool: The tool that was invoked. It may be None if tool lookup failed.
102+
tool_use: The tool parameters that were passed to the tool invoked.
103+
kwargs: Keyword arguments that were passed to the tool
104+
result: The result of the tool invocation. Either a ToolResult on success
105+
or an Exception if the tool execution failed.
106+
"""
107+
108+
selected_tool: Optional[AgentTool]
109+
tool_use: ToolUse
110+
kwargs: dict[str, Any]
111+
result: ToolResult
112+
exception: Optional[Exception] = None
113+
114+
def _can_write(self, name: str) -> bool:
115+
return name == "result"
116+
117+
@property
118+
def should_reverse_callbacks(self) -> bool:
119+
"""True to invoke callbacks in reverse order."""
64120
return True

src/strands/experimental/hooks/registry.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
from dataclasses import dataclass
11-
from typing import TYPE_CHECKING, Callable, Generator, Generic, Protocol, Type, TypeVar
11+
from typing import TYPE_CHECKING, Any, Callable, Generator, Generic, Protocol, Type, TypeVar
1212

1313
if TYPE_CHECKING:
1414
from ...agent import Agent
@@ -34,6 +34,40 @@ def should_reverse_callbacks(self) -> bool:
3434
"""
3535
return False
3636

37+
### Code below is infrastructure for disallowing updates to properties ###
38+
### that aren't expected to be updated from hook callbacks. ###
39+
40+
def _can_write(self, name: str) -> bool:
41+
"""Check if the given property can be written to.
42+
43+
Args:
44+
name: The name of the property to check.
45+
46+
Returns:
47+
True if the property can be written to, False otherwise.
48+
"""
49+
return False
50+
51+
def __post_init__(self) -> None:
52+
"""Disallow writes to non-approved properties."""
53+
# This is needed as otherwise the class can't be initialized at all, so we trigger
54+
# this after class initialization
55+
super().__setattr__("_disallow_writes", True)
56+
57+
def __setattr__(self, name: str, value: Any) -> None:
58+
"""Prevent setting attributes on hook events.
59+
60+
Raises:
61+
AttributeError: Always raised to prevent setting attributes on hook events.
62+
"""
63+
# Allow setting attributes:
64+
# - during init (when __dict__) doesn't exist
65+
# - if the subclass specifically said the property is writable
66+
if not hasattr(self, "_disallow_writes") or self._can_write(name):
67+
return super().__setattr__(name, value)
68+
69+
raise AttributeError(f"Property {name} is not writable")
70+
3771

3872
T = TypeVar("T", bound=Callable)
3973
TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True)

tests/fixtures/mock_hook_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def get_events(self) -> deque[HookEvent]:
1414

1515
def register_hooks(self, registry: HookRegistry) -> None:
1616
for event_type in self.events_types:
17-
registry.add_callback(event_type, self._add_event)
17+
registry.add_callback(event_type, self.add_event)
1818

19-
def _add_event(self, event: HookEvent) -> None:
19+
def add_event(self, event: HookEvent) -> None:
2020
self.events_received.append(event)

tests/strands/agent/test_agent_hooks.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
1-
import unittest.mock
2-
from unittest.mock import call
1+
from unittest.mock import ANY, Mock, call, patch
32

43
import pytest
54
from pydantic import BaseModel
65

76
import strands
87
from strands import Agent
9-
from strands.experimental.hooks import AgentInitializedEvent, EndRequestEvent, StartRequestEvent
8+
from strands.experimental.hooks import (
9+
AfterToolInvocationEvent,
10+
AgentInitializedEvent,
11+
BeforeToolInvocationEvent,
12+
EndRequestEvent,
13+
StartRequestEvent,
14+
)
1015
from strands.types.content import Messages
1116
from tests.fixtures.mock_hook_provider import MockHookProvider
1217
from tests.fixtures.mocked_model_provider import MockedModelProvider
1318

1419

1520
@pytest.fixture
1621
def hook_provider():
17-
return MockHookProvider([AgentInitializedEvent, StartRequestEvent, EndRequestEvent])
22+
return MockHookProvider(
23+
[AgentInitializedEvent, StartRequestEvent, EndRequestEvent, AfterToolInvocationEvent, BeforeToolInvocationEvent]
24+
)
1825

1926

2027
@pytest.fixture
@@ -71,7 +78,7 @@ class User(BaseModel):
7178
return User(name="Jane Doe", age=30)
7279

7380

74-
@unittest.mock.patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks")
81+
@patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks")
7582
def test_agent__init__hooks(mock_invoke_callbacks):
7683
"""Verify that the AgentInitializedEvent is emitted on Agent construction."""
7784
agent = Agent()
@@ -87,9 +94,19 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
8794
agent("test message")
8895

8996
events = hook_provider.get_events()
90-
assert len(events) == 2
9197

98+
assert len(events) == 4
9299
assert events.popleft() == StartRequestEvent(agent=agent)
100+
assert events.popleft() == BeforeToolInvocationEvent(
101+
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
102+
)
103+
assert events.popleft() == AfterToolInvocationEvent(
104+
agent=agent,
105+
selected_tool=agent_tool,
106+
tool_use=tool_use,
107+
kwargs=ANY,
108+
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
109+
)
93110
assert events.popleft() == EndRequestEvent(agent=agent)
94111

95112

@@ -105,16 +122,26 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
105122
pass
106123

107124
events = hook_provider.get_events()
108-
assert len(events) == 2
109125

126+
assert len(events) == 4
110127
assert events.popleft() == StartRequestEvent(agent=agent)
128+
assert events.popleft() == BeforeToolInvocationEvent(
129+
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
130+
)
131+
assert events.popleft() == AfterToolInvocationEvent(
132+
agent=agent,
133+
selected_tool=agent_tool,
134+
tool_use=tool_use,
135+
kwargs=ANY,
136+
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
137+
)
111138
assert events.popleft() == EndRequestEvent(agent=agent)
112139

113140

114141
def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
115142
"""Verify that the correct hook events are emitted as part of structured_output."""
116143

117-
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
144+
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
118145
agent.structured_output(type(user), "example prompt")
119146

120147
assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]
@@ -124,7 +151,7 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
124151
async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator):
125152
"""Verify that the correct hook events are emitted as part of structured_output_async."""
126153

127-
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
154+
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
128155
await agent.structured_output_async(type(user), "example prompt")
129156

130157
assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]

0 commit comments

Comments
 (0)