1010 AgentInitializedEvent ,
1111 BeforeToolInvocationEvent ,
1212 EndRequestEvent ,
13+ MessageAddedEvent ,
1314 StartRequestEvent ,
15+ get_registry ,
1416)
1517from strands .types .content import Messages
18+ from strands .types .tools import ToolResult , ToolUse
1619from tests .fixtures .mock_hook_provider import MockHookProvider
1720from tests .fixtures .mocked_model_provider import MockedModelProvider
1821
1922
2023@pytest .fixture
2124def hook_provider ():
2225 return MockHookProvider (
23- [AgentInitializedEvent , StartRequestEvent , EndRequestEvent , AfterToolInvocationEvent , BeforeToolInvocationEvent ]
26+ [
27+ AgentInitializedEvent ,
28+ StartRequestEvent ,
29+ EndRequestEvent ,
30+ AfterToolInvocationEvent ,
31+ BeforeToolInvocationEvent ,
32+ MessageAddedEvent ,
33+ ]
2434 )
2535
2636
@@ -63,8 +73,13 @@ def agent(
6373 tools = [agent_tool ],
6474 )
6575
66- # for now, hooks are private
67- agent ._hooks .add_hook (hook_provider )
76+ hooks = get_registry (agent )
77+ hooks .add_hook (hook_provider )
78+
79+ def assert_message_is_last_message_added (event : MessageAddedEvent ):
80+ assert event .agent .messages [- 1 ] == event .message
81+
82+ hooks .add_callback (MessageAddedEvent , assert_message_is_last_message_added )
6883
6984 return agent
7085
@@ -88,15 +103,49 @@ def test_agent__init__hooks(mock_invoke_callbacks):
88103 assert mock_invoke_callbacks .call_args == call (AgentInitializedEvent (agent = agent ))
89104
90105
106+ def test_agent_tool_call (agent , hook_provider , agent_tool ):
107+ agent .tool .tool_decorated (random_string = "a string" )
108+
109+ length , events = hook_provider .get_events ()
110+
111+ tool_use : ToolUse = {"input" : {"random_string" : "a string" }, "name" : "tool_decorated" , "toolUseId" : ANY }
112+ result : ToolResult = {"content" : [{"text" : "gnirts a" }], "status" : "success" , "toolUseId" : ANY }
113+
114+ assert length == 6
115+
116+ assert next (events ) == BeforeToolInvocationEvent (
117+ agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
118+ )
119+ assert next (events ) == AfterToolInvocationEvent (
120+ agent = agent ,
121+ selected_tool = agent_tool ,
122+ tool_use = tool_use ,
123+ kwargs = ANY ,
124+ result = result ,
125+ )
126+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [0 ])
127+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [1 ])
128+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [2 ])
129+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [3 ])
130+
131+ assert len (agent .messages ) == 4
132+
133+
91134def test_agent__call__hooks (agent , hook_provider , agent_tool , tool_use ):
92135 """Verify that the correct hook events are emitted as part of __call__."""
93136
94137 agent ("test message" )
95138
96139 length , events = hook_provider .get_events ()
97140
98- assert length == 4
141+ assert length == 8
142+
99143 assert next (events ) == StartRequestEvent (agent = agent )
144+ assert next (events ) == MessageAddedEvent (
145+ agent = agent ,
146+ message = agent .messages [0 ],
147+ )
148+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [1 ])
100149 assert next (events ) == BeforeToolInvocationEvent (
101150 agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
102151 )
@@ -107,8 +156,12 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
107156 kwargs = ANY ,
108157 result = {"content" : [{"text" : "!loot a dekovni I" }], "status" : "success" , "toolUseId" : "123" },
109158 )
159+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [2 ])
160+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [3 ])
110161 assert next (events ) == EndRequestEvent (agent = agent )
111162
163+ assert len (agent .messages ) == 4
164+
112165
113166@pytest .mark .asyncio
114167async def test_agent_stream_async_hooks (agent , hook_provider , agent_tool , tool_use ):
@@ -123,9 +176,14 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
123176
124177 length , events = hook_provider .get_events ()
125178
126- assert length == 4
179+ assert length == 8
127180
128181 assert next (events ) == StartRequestEvent (agent = agent )
182+ assert next (events ) == MessageAddedEvent (
183+ agent = agent ,
184+ message = agent .messages [0 ],
185+ )
186+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [1 ])
129187 assert next (events ) == BeforeToolInvocationEvent (
130188 agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
131189 )
@@ -136,16 +194,28 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
136194 kwargs = ANY ,
137195 result = {"content" : [{"text" : "!loot a dekovni I" }], "status" : "success" , "toolUseId" : "123" },
138196 )
197+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [2 ])
198+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [3 ])
139199 assert next (events ) == EndRequestEvent (agent = agent )
140200
201+ assert len (agent .messages ) == 4
202+
141203
142204def test_agent_structured_output_hooks (agent , hook_provider , user , agenerator ):
143205 """Verify that the correct hook events are emitted as part of structured_output."""
144206
145207 agent .model .structured_output = Mock (return_value = agenerator ([{"output" : user }]))
146208 agent .structured_output (type (user ), "example prompt" )
147209
148- assert hook_provider .events_received == [StartRequestEvent (agent = agent ), EndRequestEvent (agent = agent )]
210+ length , events = hook_provider .get_events ()
211+
212+ assert length == 3
213+
214+ assert next (events ) == StartRequestEvent (agent = agent )
215+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [0 ])
216+ assert next (events ) == EndRequestEvent (agent = agent )
217+
218+ assert len (agent .messages ) == 1
149219
150220
151221@pytest .mark .asyncio
@@ -155,4 +225,12 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
155225 agent .model .structured_output = Mock (return_value = agenerator ([{"output" : user }]))
156226 await agent .structured_output_async (type (user ), "example prompt" )
157227
158- assert hook_provider .events_received == [StartRequestEvent (agent = agent ), EndRequestEvent (agent = agent )]
228+ length , events = hook_provider .get_events ()
229+
230+ assert length == 3
231+
232+ assert next (events ) == StartRequestEvent (agent = agent )
233+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [0 ])
234+ assert next (events ) == EndRequestEvent (agent = agent )
235+
236+ assert len (agent .messages ) == 1
0 commit comments