Skip to content

Commit 70ede37

Browse files
committed
Merge branch 'main' of https://github.com/strands-agents/sdk-python into multi-modal-so
2 parents 2814565 + 513f32b commit 70ede37

File tree

10 files changed

+305
-34
lines changed

10 files changed

+305
-34
lines changed

src/strands/agent/agent.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@
2020
from pydantic import BaseModel
2121

2222
from ..event_loop.event_loop import event_loop_cycle, run_tool
23-
from ..experimental.hooks import AgentInitializedEvent, EndRequestEvent, HookRegistry, StartRequestEvent
23+
from ..experimental.hooks import (
24+
AgentInitializedEvent,
25+
EndRequestEvent,
26+
HookRegistry,
27+
MessageAddedEvent,
28+
StartRequestEvent,
29+
)
2430
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2531
from ..models.bedrock import BedrockModel
2632
from ..telemetry.metrics import EventLoopMetrics
@@ -427,7 +433,7 @@ async def structured_output_async(
427433
# add the prompt as the last message
428434
if prompt:
429435
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
430-
self.messages.append({"role": "user", "content": content})
436+
self._append_message({"role": "user", "content": content})
431437

432438
events = self.model.structured_output(output_model, self.messages)
433439
async for event in events:
@@ -508,7 +514,7 @@ async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGene
508514
try:
509515
yield {"callback": {"init_event_loop": True, **kwargs}}
510516

511-
self.messages.append(message)
517+
self._append_message(message)
512518

513519
# Execute the event loop cycle with retry logic for context limits
514520
events = self._execute_event_loop_cycle(kwargs)
@@ -598,10 +604,10 @@ def _record_tool_execution(
598604
}
599605

600606
# Add to message history
601-
messages.append(user_msg)
602-
messages.append(tool_use_msg)
603-
messages.append(tool_result_msg)
604-
messages.append(assistant_msg)
607+
self._append_message(user_msg)
608+
self._append_message(tool_use_msg)
609+
self._append_message(tool_result_msg)
610+
self._append_message(assistant_msg)
605611

606612
def _start_agent_trace_span(self, message: Message) -> None:
607613
"""Starts a trace span for the agent.
@@ -643,3 +649,8 @@ def _end_agent_trace_span(
643649
trace_attributes["error"] = error
644650

645651
self.tracer.end_agent_span(**trace_attributes)
652+
653+
def _append_message(self, message: Message) -> None:
654+
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
655+
self.messages.append(message)
656+
self._hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))

src/strands/event_loop/event_loop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
1515

1616
from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
17+
from ..experimental.hooks.events import MessageAddedEvent
1718
from ..experimental.hooks.registry import get_registry
1819
from ..telemetry.metrics import Trace
1920
from ..telemetry.tracer import get_tracer
@@ -166,6 +167,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
166167

167168
# Add the response message to the conversation
168169
agent.messages.append(message)
170+
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
169171
yield {"callback": {"message": message}}
170172

171173
# Update metrics
@@ -431,6 +433,7 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator:
431433
}
432434

433435
agent.messages.append(tool_result_message)
436+
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
434437
yield {"callback": {"message": tool_result_message}}
435438

436439
if cycle_span:

src/strands/experimental/hooks/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,21 @@ def log_end(self, event: EndRequestEvent) -> None:
3434
AgentInitializedEvent,
3535
BeforeToolInvocationEvent,
3636
EndRequestEvent,
37+
MessageAddedEvent,
3738
StartRequestEvent,
3839
)
39-
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry
40+
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry, get_registry
4041

4142
__all__ = [
4243
"AgentInitializedEvent",
4344
"StartRequestEvent",
4445
"EndRequestEvent",
4546
"BeforeToolInvocationEvent",
4647
"AfterToolInvocationEvent",
48+
"MessageAddedEvent",
4749
"HookEvent",
4850
"HookProvider",
4951
"HookCallback",
5052
"HookRegistry",
53+
"get_registry",
5154
]

src/strands/experimental/hooks/events.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dataclasses import dataclass
77
from typing import Any, Optional
88

9+
from ...types.content import Message
910
from ...types.tools import AgentTool, ToolResult, ToolUse
1011
from .registry import HookEvent
1112

@@ -118,3 +119,22 @@ def _can_write(self, name: str) -> bool:
118119
def should_reverse_callbacks(self) -> bool:
119120
"""True to invoke callbacks in reverse order."""
120121
return True
122+
123+
124+
@dataclass
125+
class MessageAddedEvent(HookEvent):
126+
"""Event triggered when a message is added to the agent's conversation.
127+
128+
This event is fired whenever the agent adds a new message to its internal
129+
message history, including user messages, assistant responses, and tool
130+
results. Hook providers can use this event for logging, monitoring, or
131+
implementing custom message processing logic.
132+
133+
Note: This event is only triggered for messages added by the framework
134+
itself, not for messages manually added by tools or external code.
135+
136+
Attributes:
137+
message: The message that was added to the conversation history.
138+
"""
139+
140+
message: Message

src/strands/tools/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def validate_tool_use_name(tool: ToolUse) -> None:
4848
raise InvalidToolUseNameException(message)
4949

5050
tool_name = tool["name"]
51-
tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_\-]*$"
51+
tool_name_pattern = r"^[a-zA-Z0-9_\-]{1,}$"
5252
tool_name_max_length = 64
5353
valid_name_pattern = bool(re.match(tool_name_pattern, tool_name))
5454
tool_name_len = len(tool_name)

tests/strands/agent/test_agent_hooks.py

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,27 @@
1010
AgentInitializedEvent,
1111
BeforeToolInvocationEvent,
1212
EndRequestEvent,
13+
MessageAddedEvent,
1314
StartRequestEvent,
15+
get_registry,
1416
)
1517
from strands.types.content import Messages
18+
from strands.types.tools import ToolResult, ToolUse
1619
from tests.fixtures.mock_hook_provider import MockHookProvider
1720
from tests.fixtures.mocked_model_provider import MockedModelProvider
1821

1922

2023
@pytest.fixture
2124
def hook_provider():
2225
return MockHookProvider(
23-
[AgentInitializedEvent, StartRequestEvent, EndRequestEvent, AfterToolInvocationEvent, BeforeToolInvocationEvent]
26+
[
27+
AgentInitializedEvent,
28+
StartRequestEvent,
29+
EndRequestEvent,
30+
AfterToolInvocationEvent,
31+
BeforeToolInvocationEvent,
32+
MessageAddedEvent,
33+
]
2434
)
2535

2636

@@ -63,8 +73,13 @@ def agent(
6373
tools=[agent_tool],
6474
)
6575

66-
# for now, hooks are private
67-
agent._hooks.add_hook(hook_provider)
76+
hooks = get_registry(agent)
77+
hooks.add_hook(hook_provider)
78+
79+
def assert_message_is_last_message_added(event: MessageAddedEvent):
80+
assert event.agent.messages[-1] == event.message
81+
82+
hooks.add_callback(MessageAddedEvent, assert_message_is_last_message_added)
6883

6984
return agent
7085

@@ -88,15 +103,49 @@ def test_agent__init__hooks(mock_invoke_callbacks):
88103
assert mock_invoke_callbacks.call_args == call(AgentInitializedEvent(agent=agent))
89104

90105

106+
def test_agent_tool_call(agent, hook_provider, agent_tool):
107+
agent.tool.tool_decorated(random_string="a string")
108+
109+
length, events = hook_provider.get_events()
110+
111+
tool_use: ToolUse = {"input": {"random_string": "a string"}, "name": "tool_decorated", "toolUseId": ANY}
112+
result: ToolResult = {"content": [{"text": "gnirts a"}], "status": "success", "toolUseId": ANY}
113+
114+
assert length == 6
115+
116+
assert next(events) == BeforeToolInvocationEvent(
117+
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
118+
)
119+
assert next(events) == AfterToolInvocationEvent(
120+
agent=agent,
121+
selected_tool=agent_tool,
122+
tool_use=tool_use,
123+
kwargs=ANY,
124+
result=result,
125+
)
126+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
127+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
128+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
129+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
130+
131+
assert len(agent.messages) == 4
132+
133+
91134
def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
92135
"""Verify that the correct hook events are emitted as part of __call__."""
93136

94137
agent("test message")
95138

96139
length, events = hook_provider.get_events()
97140

98-
assert length == 4
141+
assert length == 8
142+
99143
assert next(events) == StartRequestEvent(agent=agent)
144+
assert next(events) == MessageAddedEvent(
145+
agent=agent,
146+
message=agent.messages[0],
147+
)
148+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
100149
assert next(events) == BeforeToolInvocationEvent(
101150
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
102151
)
@@ -107,8 +156,12 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
107156
kwargs=ANY,
108157
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
109158
)
159+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
160+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
110161
assert next(events) == EndRequestEvent(agent=agent)
111162

163+
assert len(agent.messages) == 4
164+
112165

113166
@pytest.mark.asyncio
114167
async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_use):
@@ -123,9 +176,14 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
123176

124177
length, events = hook_provider.get_events()
125178

126-
assert length == 4
179+
assert length == 8
127180

128181
assert next(events) == StartRequestEvent(agent=agent)
182+
assert next(events) == MessageAddedEvent(
183+
agent=agent,
184+
message=agent.messages[0],
185+
)
186+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
129187
assert next(events) == BeforeToolInvocationEvent(
130188
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
131189
)
@@ -136,16 +194,28 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
136194
kwargs=ANY,
137195
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
138196
)
197+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
198+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
139199
assert next(events) == EndRequestEvent(agent=agent)
140200

201+
assert len(agent.messages) == 4
202+
141203

142204
def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
143205
"""Verify that the correct hook events are emitted as part of structured_output."""
144206

145207
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
146208
agent.structured_output(type(user), "example prompt")
147209

148-
assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]
210+
length, events = hook_provider.get_events()
211+
212+
assert length == 3
213+
214+
assert next(events) == StartRequestEvent(agent=agent)
215+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
216+
assert next(events) == EndRequestEvent(agent=agent)
217+
218+
assert len(agent.messages) == 1
149219

150220

151221
@pytest.mark.asyncio
@@ -155,4 +225,12 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
155225
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
156226
await agent.structured_output_async(type(user), "example prompt")
157227

158-
assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]
228+
length, events = hook_provider.get_events()
229+
230+
assert length == 3
231+
232+
assert next(events) == StartRequestEvent(agent=agent)
233+
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
234+
assert next(events) == EndRequestEvent(agent=agent)
235+
236+
assert len(agent.messages) == 1

tests/strands/experimental/hooks/test_events.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
import pytest
44

5-
from strands.experimental.hooks.events import (
5+
from strands.experimental.hooks import (
66
AfterToolInvocationEvent,
77
AgentInitializedEvent,
88
BeforeToolInvocationEvent,
99
EndRequestEvent,
10+
MessageAddedEvent,
1011
StartRequestEvent,
1112
)
1213
from strands.types.tools import ToolResult, ToolUse
@@ -49,6 +50,11 @@ def start_request_event(agent):
4950
return StartRequestEvent(agent=agent)
5051

5152

53+
@pytest.fixture
54+
def messaged_added_event(agent):
55+
return MessageAddedEvent(agent=agent, message=Mock())
56+
57+
5258
@pytest.fixture
5359
def end_request_event(agent):
5460
return EndRequestEvent(agent=agent)
@@ -78,6 +84,7 @@ def after_tool_event(agent, tool, tool_use, tool_kwargs, tool_result):
7884
def test_event_should_reverse_callbacks(
7985
initialized_event,
8086
start_request_event,
87+
messaged_added_event,
8188
end_request_event,
8289
before_tool_event,
8390
after_tool_event,
@@ -86,13 +93,22 @@ def test_event_should_reverse_callbacks(
8693

8794
assert initialized_event.should_reverse_callbacks == False # noqa: E712
8895

96+
assert messaged_added_event.should_reverse_callbacks == False # noqa: E712
97+
8998
assert start_request_event.should_reverse_callbacks == False # noqa: E712
9099
assert end_request_event.should_reverse_callbacks == True # noqa: E712
91100

92101
assert before_tool_event.should_reverse_callbacks == False # noqa: E712
93102
assert after_tool_event.should_reverse_callbacks == True # noqa: E712
94103

95104

105+
def test_message_added_event_cannot_write_properties(messaged_added_event):
106+
with pytest.raises(AttributeError, match="Property agent is not writable"):
107+
messaged_added_event.agent = Mock()
108+
with pytest.raises(AttributeError, match="Property message is not writable"):
109+
messaged_added_event.message = {}
110+
111+
96112
def test_before_tool_invocation_event_can_write_properties(before_tool_event):
97113
new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={})
98114
before_tool_event.selected_tool = None # Should not raise

0 commit comments

Comments
 (0)