Skip to content

Commit ddf7049

Browse files
committed
chore: extract hook based tests to a separate file
Rather than keeping agent hook tests in the test_agent file, extract them out to a separate file for readability/better-separation-of-concerns. Also updated the hook tests verify each hook-event individually to be more readable/easier to debug.
1 parent 1215b88 commit ddf7049

File tree

4 files changed

+125
-55
lines changed

4 files changed

+125
-55
lines changed

tests/fixtures/mock_hook_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import deque
12
from typing import Type
23

34
from strands.experimental.hooks import HookEvent, HookProvider, HookRegistry
@@ -8,6 +9,9 @@ def __init__(self, event_types: list[Type]):
89
self.events_received = []
910
self.events_types = event_types
1011

12+
def get_events(self) -> deque[HookEvent]:
13+
return deque(self.events_received)
14+
1115
def register_hooks(self, registry: HookRegistry) -> None:
1216
for event_type in self.events_types:
1317
registry.add_callback(event_type, self._add_event)

tests/fixtures/mocked_model_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[s
6767
}
6868
}
6969
}
70-
yield {"contentBlockDelta": {"delta": {"tool_use": {"input": json.dumps(content["toolUse"]["input"])}}}}
70+
yield {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}}}
7171
yield {"contentBlockStop": {}}
7272

7373
yield {"messageStop": {"stopReason": stop_reason}}

tests/strands/agent/test_agent.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import os
55
import textwrap
66
import unittest.mock
7-
from unittest.mock import call
87

98
import pytest
109
from pydantic import BaseModel
@@ -14,12 +13,10 @@
1413
from strands.agent import AgentResult
1514
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
1615
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
17-
from strands.experimental.hooks import AgentInitializedEvent, EndRequestEvent, StartRequestEvent
1816
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
1917
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
2018
from strands.types.content import Messages
2119
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
22-
from tests.fixtures.mock_hook_provider import MockHookProvider
2320

2421

2522
@pytest.fixture
@@ -162,11 +159,6 @@ def tools(request, tool):
162159
return request.param if hasattr(request, "param") else [tool_decorated]
163160

164161

165-
@pytest.fixture
166-
def hook_provider():
167-
return MockHookProvider([AgentInitializedEvent, StartRequestEvent, EndRequestEvent])
168-
169-
170162
@pytest.fixture
171163
def agent(
172164
mock_model,
@@ -178,7 +170,6 @@ def agent(
178170
tool_registry,
179171
tool_decorated,
180172
request,
181-
hook_provider,
182173
):
183174
agent = Agent(
184175
model=mock_model,
@@ -188,9 +179,6 @@ def agent(
188179
tools=tools,
189180
)
190181

191-
# for now, hooks are private
192-
agent._hooks.add_hook(hook_provider)
193-
194182
# Only register the tool directly if tools wasn't parameterized
195183
if not hasattr(request, "param") or request.param is None:
196184
# Create a new function tool directly from the decorated function
@@ -720,48 +708,6 @@ def test_agent__call__callback(mock_model, agent, callback_handler):
720708
)
721709

722710

723-
@unittest.mock.patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks")
724-
def test_agent_hooks__init__(mock_invoke_callbacks):
725-
"""Verify that the AgentInitializedEvent is emitted on Agent construction."""
726-
agent = Agent()
727-
728-
# Verify AgentInitialized event was invoked
729-
mock_invoke_callbacks.assert_called_once()
730-
assert mock_invoke_callbacks.call_args == call(AgentInitializedEvent(agent=agent))
731-
732-
733-
def test_agent_hooks__call__(agent, mock_hook_messages, hook_provider):
734-
"""Verify that the correct hook events are emitted as part of __call__."""
735-
736-
agent("test message")
737-
738-
assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]
739-
740-
741-
@pytest.mark.asyncio
742-
async def test_agent_hooks_stream_async(agent, mock_hook_messages, hook_provider):
743-
"""Verify that the correct hook events are emitted as part of stream_async."""
744-
iterator = agent.stream_async("test message")
745-
await anext(iterator)
746-
assert hook_provider.events_received == [StartRequestEvent(agent=agent)]
747-
748-
# iterate the rest
749-
async for _ in iterator:
750-
pass
751-
752-
assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]
753-
754-
755-
def test_agent_hooks_structured_output(agent, mock_hook_messages, hook_provider):
756-
"""Verify that the correct hook events are emitted as part of structured_output."""
757-
758-
expected_user = User(name="Jane Doe", age=30, email="jane@doe.com")
759-
agent.model.structured_output = unittest.mock.Mock(return_value=[{"output": expected_user}])
760-
agent.structured_output(User, "example prompt")
761-
762-
assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]
763-
764-
765711
def test_agent_tool(mock_randint, agent):
766712
conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager)
767713
agent.conversation_manager = conversation_manager_spy
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import unittest.mock
2+
from unittest.mock import call
3+
4+
import pytest
5+
from pydantic import BaseModel
6+
7+
import strands
8+
from strands import Agent
9+
from strands.experimental.hooks import AgentInitializedEvent, EndRequestEvent, StartRequestEvent
10+
from strands.types.content import Messages
11+
from tests.fixtures.mock_hook_provider import MockHookProvider
12+
from tests.fixtures.mocked_model_provider import MockedModelProvider
13+
14+
15+
@pytest.fixture
16+
def hook_provider():
17+
return MockHookProvider([AgentInitializedEvent, StartRequestEvent, EndRequestEvent])
18+
19+
20+
@pytest.fixture
21+
def agent_tool():
22+
@strands.tools.tool(name="tool_decorated")
23+
def reverse(random_string: str) -> str:
24+
return random_string[::-1]
25+
26+
return reverse
27+
28+
29+
@pytest.fixture
30+
def tool_use(agent_tool):
31+
return {"name": agent_tool.tool_name, "toolUseId": "123", "input": {"random_string": "I invoked a tool!"}}
32+
33+
34+
@pytest.fixture
35+
def mock_model(tool_use):
36+
agent_messages: Messages = [
37+
{
38+
"role": "assistant",
39+
"content": [{"toolUse": tool_use}],
40+
},
41+
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
42+
]
43+
return MockedModelProvider(agent_messages)
44+
45+
46+
@pytest.fixture
47+
def agent(
48+
mock_model,
49+
hook_provider,
50+
agent_tool,
51+
):
52+
agent = Agent(
53+
model=mock_model,
54+
system_prompt="You are a helpful assistant.",
55+
callback_handler=None,
56+
tools=[agent_tool],
57+
)
58+
59+
# for now, hooks are private
60+
agent._hooks.add_hook(hook_provider)
61+
62+
return agent
63+
64+
65+
# mock the User(name='Jane Doe', age=30)
66+
class User(BaseModel):
67+
"""A user of the system."""
68+
69+
name: str
70+
age: int
71+
72+
73+
@unittest.mock.patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks")
74+
def test_agent_hooks__init__(mock_invoke_callbacks):
75+
"""Verify that the AgentInitializedEvent is emitted on Agent construction."""
76+
agent = Agent()
77+
78+
# Verify AgentInitialized event was invoked
79+
mock_invoke_callbacks.assert_called_once()
80+
assert mock_invoke_callbacks.call_args == call(AgentInitializedEvent(agent=agent))
81+
82+
83+
def test_agent_hooks__call__(agent, hook_provider, agent_tool, tool_use):
84+
"""Verify that the correct hook events are emitted as part of __call__."""
85+
86+
agent("test message")
87+
88+
events = hook_provider.get_events()
89+
assert len(events) == 2
90+
91+
assert events.popleft() == StartRequestEvent(agent=agent)
92+
assert events.popleft() == EndRequestEvent(agent=agent)
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_agent_hooks_stream_async(agent, hook_provider, agent_tool, tool_use):
97+
"""Verify that the correct hook events are emitted as part of stream_async."""
98+
iterator = agent.stream_async("test message")
99+
await anext(iterator)
100+
assert hook_provider.events_received == [StartRequestEvent(agent=agent)]
101+
102+
# iterate the rest
103+
async for _ in iterator:
104+
pass
105+
106+
events = hook_provider.get_events()
107+
assert len(events) == 2
108+
109+
assert events.popleft() == StartRequestEvent(agent=agent)
110+
assert events.popleft() == EndRequestEvent(agent=agent)
111+
112+
113+
def test_agent_hooks_structured_output(agent, hook_provider):
114+
"""Verify that the correct hook events are emitted as part of structured_output."""
115+
116+
expected_user = User(name="Jane Doe", age=30)
117+
agent.model.structured_output = unittest.mock.Mock(return_value=[{"output": expected_user}])
118+
agent.structured_output(User, "example prompt")
119+
120+
assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]

0 commit comments

Comments
 (0)