10
10
AgentInitializedEvent ,
11
11
BeforeToolInvocationEvent ,
12
12
EndRequestEvent ,
13
+ MessageAddedEvent ,
13
14
StartRequestEvent ,
15
+ get_registry ,
14
16
)
15
17
from strands .types .content import Messages
18
+ from strands .types .tools import ToolResult , ToolUse
16
19
from tests .fixtures .mock_hook_provider import MockHookProvider
17
20
from tests .fixtures .mocked_model_provider import MockedModelProvider
18
21
19
22
20
23
@pytest .fixture
21
24
def hook_provider ():
22
25
return MockHookProvider (
23
- [AgentInitializedEvent , StartRequestEvent , EndRequestEvent , AfterToolInvocationEvent , BeforeToolInvocationEvent ]
26
+ [
27
+ AgentInitializedEvent ,
28
+ StartRequestEvent ,
29
+ EndRequestEvent ,
30
+ AfterToolInvocationEvent ,
31
+ BeforeToolInvocationEvent ,
32
+ MessageAddedEvent ,
33
+ ]
24
34
)
25
35
26
36
@@ -63,8 +73,13 @@ def agent(
63
73
tools = [agent_tool ],
64
74
)
65
75
66
- # for now, hooks are private
67
- agent ._hooks .add_hook (hook_provider )
76
+ hooks = get_registry (agent )
77
+ hooks .add_hook (hook_provider )
78
+
79
+ def assert_message_is_last_message_added (event : MessageAddedEvent ):
80
+ assert event .agent .messages [- 1 ] == event .message
81
+
82
+ hooks .add_callback (MessageAddedEvent , assert_message_is_last_message_added )
68
83
69
84
return agent
70
85
@@ -88,15 +103,49 @@ def test_agent__init__hooks(mock_invoke_callbacks):
88
103
assert mock_invoke_callbacks .call_args == call (AgentInitializedEvent (agent = agent ))
89
104
90
105
106
+ def test_agent_tool_call (agent , hook_provider , agent_tool ):
107
+ agent .tool .tool_decorated (random_string = "a string" )
108
+
109
+ length , events = hook_provider .get_events ()
110
+
111
+ tool_use : ToolUse = {"input" : {"random_string" : "a string" }, "name" : "tool_decorated" , "toolUseId" : ANY }
112
+ result : ToolResult = {"content" : [{"text" : "gnirts a" }], "status" : "success" , "toolUseId" : ANY }
113
+
114
+ assert length == 6
115
+
116
+ assert next (events ) == BeforeToolInvocationEvent (
117
+ agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
118
+ )
119
+ assert next (events ) == AfterToolInvocationEvent (
120
+ agent = agent ,
121
+ selected_tool = agent_tool ,
122
+ tool_use = tool_use ,
123
+ kwargs = ANY ,
124
+ result = result ,
125
+ )
126
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [0 ])
127
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [1 ])
128
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [2 ])
129
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [3 ])
130
+
131
+ assert len (agent .messages ) == 4
132
+
133
+
91
134
def test_agent__call__hooks (agent , hook_provider , agent_tool , tool_use ):
92
135
"""Verify that the correct hook events are emitted as part of __call__."""
93
136
94
137
agent ("test message" )
95
138
96
139
length , events = hook_provider .get_events ()
97
140
98
- assert length == 4
141
+ assert length == 8
142
+
99
143
assert next (events ) == StartRequestEvent (agent = agent )
144
+ assert next (events ) == MessageAddedEvent (
145
+ agent = agent ,
146
+ message = agent .messages [0 ],
147
+ )
148
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [1 ])
100
149
assert next (events ) == BeforeToolInvocationEvent (
101
150
agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
102
151
)
@@ -107,8 +156,12 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
107
156
kwargs = ANY ,
108
157
result = {"content" : [{"text" : "!loot a dekovni I" }], "status" : "success" , "toolUseId" : "123" },
109
158
)
159
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [2 ])
160
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [3 ])
110
161
assert next (events ) == EndRequestEvent (agent = agent )
111
162
163
+ assert len (agent .messages ) == 4
164
+
112
165
113
166
@pytest .mark .asyncio
114
167
async def test_agent_stream_async_hooks (agent , hook_provider , agent_tool , tool_use ):
@@ -123,9 +176,14 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
123
176
124
177
length , events = hook_provider .get_events ()
125
178
126
- assert length == 4
179
+ assert length == 8
127
180
128
181
assert next (events ) == StartRequestEvent (agent = agent )
182
+ assert next (events ) == MessageAddedEvent (
183
+ agent = agent ,
184
+ message = agent .messages [0 ],
185
+ )
186
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [1 ])
129
187
assert next (events ) == BeforeToolInvocationEvent (
130
188
agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
131
189
)
@@ -136,16 +194,28 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
136
194
kwargs = ANY ,
137
195
result = {"content" : [{"text" : "!loot a dekovni I" }], "status" : "success" , "toolUseId" : "123" },
138
196
)
197
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [2 ])
198
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [3 ])
139
199
assert next (events ) == EndRequestEvent (agent = agent )
140
200
201
+ assert len (agent .messages ) == 4
202
+
141
203
142
204
def test_agent_structured_output_hooks (agent , hook_provider , user , agenerator ):
143
205
"""Verify that the correct hook events are emitted as part of structured_output."""
144
206
145
207
agent .model .structured_output = Mock (return_value = agenerator ([{"output" : user }]))
146
208
agent .structured_output (type (user ), "example prompt" )
147
209
148
- assert hook_provider .events_received == [StartRequestEvent (agent = agent ), EndRequestEvent (agent = agent )]
210
+ length , events = hook_provider .get_events ()
211
+
212
+ assert length == 3
213
+
214
+ assert next (events ) == StartRequestEvent (agent = agent )
215
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [0 ])
216
+ assert next (events ) == EndRequestEvent (agent = agent )
217
+
218
+ assert len (agent .messages ) == 1
149
219
150
220
151
221
@pytest .mark .asyncio
@@ -155,4 +225,12 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
155
225
agent .model .structured_output = Mock (return_value = agenerator ([{"output" : user }]))
156
226
await agent .structured_output_async (type (user ), "example prompt" )
157
227
158
- assert hook_provider .events_received == [StartRequestEvent (agent = agent ), EndRequestEvent (agent = agent )]
228
+ length , events = hook_provider .get_events ()
229
+
230
+ assert length == 3
231
+
232
+ assert next (events ) == StartRequestEvent (agent = agent )
233
+ assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [0 ])
234
+ assert next (events ) == EndRequestEvent (agent = agent )
235
+
236
+ assert len (agent .messages ) == 1
0 commit comments