Skip to content

Commit 2a43618

Browse files
committed
covers all 8 overrideable parameters (system_prompt, model, tool_execution_handler, event_loop_metrics, callback_handler, tool_handler, messages, tool_config)
1 parent e1e19de commit 2a43618

File tree

1 file changed

+36
-212
lines changed

1 file changed

+36
-212
lines changed

tests/strands/agent/test_agent.py

Lines changed: 36 additions & 212 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,20 @@
33
import os
44
import textwrap
55
import threading
6-
import time
76
import unittest.mock
87
from time import sleep
98

109
import pytest
1110

1211
import strands
12+
import strands.tools
1313
from strands.agent.agent import Agent
1414
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
1515
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
1616
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
1717
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
1818
from strands.types.content import Messages
1919
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
20-
from strands.types.models import Model
2120

2221

2322
@pytest.fixture
@@ -163,214 +162,6 @@ def agent(
163162
return agent
164163

165164

166-
def test_agent_system_prompt_overrides_all_cases():
167-
"""Test all system prompt override scenarios and all 8 overrideable parameters.
168-
169-
This comprehensive test ensures that:
170-
1. System prompt overrides work in all scenarios
171-
2. All 8 parameters that can be overridden in _execute_event_loop_cycle are properly handled
172-
3. Prevents future regressions for all override functionality
173-
"""
174-
# Enhanced mock model that tracks all calls and parameters
175-
class ComprehensiveMockModel(Model):
176-
def __init__(self, model_id="mock-model"):
177-
self.model_id = model_id
178-
self.captured_system_prompts = []
179-
self.captured_calls = []
180-
181-
def update_config(self, **model_config):
182-
pass
183-
184-
def get_config(self):
185-
return {"model_id": self.model_id}
186-
187-
def format_request(self, messages, tool_specs=None, system_prompt=None):
188-
self.captured_system_prompts.append(system_prompt)
189-
return {"messages": messages, "tool_specs": tool_specs, "system_prompt": system_prompt}
190-
191-
def format_chunk(self, event):
192-
return {"messageStart": {"role": "assistant"}}
193-
194-
def stream(self, request):
195-
yield {"contentBlockDelta": {"delta": {"text": "Mock response"}}}
196-
yield {"contentBlockStop": {}}
197-
yield {"messageStop": {"stopReason": "end_turn"}}
198-
199-
def converse(self, messages, tool_specs=None, system_prompt=None):
200-
# Call format_request to capture system prompts like the base class does
201-
self.format_request(messages, tool_specs, system_prompt)
202-
203-
self.captured_calls.append({
204-
'system_prompt': system_prompt,
205-
'messages': messages,
206-
'tool_specs': tool_specs,
207-
'kwargs': {}
208-
})
209-
return [
210-
{"contentBlockStart": {"start": {}}},
211-
{"contentBlockDelta": {"delta": {"text": "Test response"}}},
212-
{"contentBlockStop": {}},
213-
{"messageStop": {"stopReason": "end_turn"}},
214-
]
215-
216-
# Mock classes for complex dependencies
217-
class MockToolHandler:
218-
def __init__(self, name):
219-
self.name = name
220-
def get_tools(self):
221-
return []
222-
223-
class MockCallbackHandler:
224-
def __init__(self, name):
225-
self.name = name
226-
227-
def __call__(self, **kwargs):
228-
# Mock callback handler that does nothing
229-
pass
230-
231-
class MockTrace:
232-
def __init__(self, name):
233-
self.name = name
234-
self.id = "mock-trace-id"
235-
def add_child(self, child):
236-
pass
237-
238-
class MockMetrics:
239-
def __init__(self, name):
240-
self.name = name
241-
self.cycle_count = 0
242-
self.cycle_durations = []
243-
self.traces = []
244-
245-
def start_cycle(self):
246-
self.cycle_count += 1
247-
start_time = time.time()
248-
cycle_trace = MockTrace(f"Cycle {self.cycle_count}")
249-
self.traces.append(cycle_trace)
250-
return start_time, cycle_trace
251-
252-
def end_cycle(self, start_time, cycle_trace):
253-
duration = time.time() - start_time
254-
self.cycle_durations.append(duration)
255-
256-
def update_usage(self, usage):
257-
pass
258-
259-
def update_metrics(self, metrics):
260-
pass
261-
262-
class MockExecutor:
263-
def __init__(self, name):
264-
self.name = name
265-
266-
# === PART 1: Test System Prompt Override Scenarios ===
267-
mock_model = ComprehensiveMockModel("system-prompt-test")
268-
269-
# 1. Uses default system prompt
270-
default_prompt = "You are a helpful assistant."
271-
agent = Agent(system_prompt=default_prompt, model=mock_model)
272-
agent("Hello")
273-
assert mock_model.captured_system_prompts[-1] == default_prompt
274-
275-
# 2. Override system prompt per call
276-
override_prompt = "You are a pirate."
277-
agent("Hello", system_prompt=override_prompt)
278-
assert mock_model.captured_system_prompts[-1] == override_prompt
279-
280-
# 3. Reverts to default after override
281-
agent("Hello again")
282-
assert mock_model.captured_system_prompts[-1] == default_prompt
283-
284-
# 4. Multiple overrides
285-
agent("Hi", system_prompt="You are a poet.")
286-
assert mock_model.captured_system_prompts[-1] == "You are a poet."
287-
agent("Hi", system_prompt="You are a robot.")
288-
assert mock_model.captured_system_prompts[-1] == "You are a robot."
289-
agent("Hi")
290-
assert mock_model.captured_system_prompts[-1] == default_prompt
291-
292-
# 5. Override with None
293-
agent("Test", system_prompt=None)
294-
assert mock_model.captured_system_prompts[-1] is None
295-
296-
# 6. Override with empty string
297-
agent("Test", system_prompt="")
298-
assert mock_model.captured_system_prompts[-1] == ""
299-
300-
# 7. No default system prompt
301-
agent2 = Agent(model=mock_model) # No default
302-
agent2("Hello")
303-
assert mock_model.captured_system_prompts[-1] is None
304-
agent2("Hello", system_prompt="You are helpful.")
305-
assert mock_model.captured_system_prompts[-1] == "You are helpful."
306-
307-
# === PART 2: Test All 8 Overrideable Parameters ===
308-
override_model = ComprehensiveMockModel("override-model")
309-
original_model = ComprehensiveMockModel("original-model")
310-
311-
# Create agent with original model
312-
comprehensive_agent = Agent(
313-
model=original_model,
314-
system_prompt="Default system prompt"
315-
)
316-
317-
# Test all 8 overrideable parameters
318-
override_messages = [{"role": "user", "content": [{"text": "Override message"}]}]
319-
override_tool_handler = MockToolHandler("override")
320-
override_callback = MockCallbackHandler("override")
321-
override_metrics = MockMetrics("override")
322-
override_executor = MockExecutor("override")
323-
override_tool_config = {"temperature": 0.8}
324-
325-
# Execute with all overrides
326-
comprehensive_agent(
327-
"Test comprehensive override",
328-
system_prompt="Override system prompt",
329-
model=override_model,
330-
tool_execution_handler=override_executor,
331-
event_loop_metrics=override_metrics,
332-
callback_handler=override_callback,
333-
tool_handler=override_tool_handler,
334-
messages=override_messages,
335-
tool_config=override_tool_config
336-
)
337-
338-
# Verify the overridden model was used
339-
assert len(override_model.captured_calls) == 1
340-
call = override_model.captured_calls[0]
341-
342-
# Verify overrides were applied
343-
assert call['system_prompt'] == "Override system prompt"
344-
assert call['messages'] == override_messages
345-
# Note: tool_config gets processed into tool_specs at the event loop level
346-
# The model's converse method receives tool_specs, not the raw tool_config
347-
assert call['tool_specs'] is None # No tools configured in this test
348-
349-
# Verify original model was not called during override
350-
assert len(original_model.captured_calls) == 0
351-
352-
# Test partial overrides - only override some parameters
353-
mock_model.captured_calls.clear()
354-
agent(
355-
"Another test",
356-
system_prompt="Partial override",
357-
model=mock_model
358-
# Other parameters use defaults
359-
)
360-
361-
assert len(mock_model.captured_calls) == 1
362-
partial_call = mock_model.captured_calls[0]
363-
assert partial_call['system_prompt'] == "Partial override"
364-
365-
# Test no overrides - should use defaults
366-
original_model.captured_calls.clear()
367-
comprehensive_agent("Default test")
368-
369-
assert len(original_model.captured_calls) == 1
370-
default_call = original_model.captured_calls[0]
371-
assert default_call['system_prompt'] == "Default system prompt"
372-
373-
374165
def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry):
375166
_ = tool_registry
376167

@@ -547,17 +338,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler,
547338
],
548339
]
549340

341+
override_system_prompt = "Override system prompt"
342+
override_model = unittest.mock.Mock()
343+
override_tool_execution_handler = unittest.mock.Mock()
344+
override_event_loop_metrics = unittest.mock.Mock()
345+
override_callback_handler = unittest.mock.Mock()
346+
override_tool_handler = unittest.mock.Mock()
347+
override_messages = [{"role": "user", "content": [{"text": "override msg"}]}]
348+
override_tool_config = {"test": "config"}
349+
550350
def check_kwargs(some_value, **kwargs):
551351
assert some_value == "a_value"
552352
assert kwargs is not None
353+
assert kwargs["system_prompt"] == override_system_prompt
354+
assert kwargs["model"] == override_model
355+
assert kwargs["tool_execution_handler"] == override_tool_execution_handler
356+
assert kwargs["event_loop_metrics"] == override_event_loop_metrics
357+
assert kwargs["callback_handler"] == override_callback_handler
358+
assert kwargs["tool_handler"] == override_tool_handler
359+
assert kwargs["messages"] == override_messages
360+
assert kwargs["tool_config"] == override_tool_config
361+
assert kwargs["agent"] == agent
553362

554363
# Return expected values from event_loop_cycle
555364
return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}
556365

557366
mock_event_loop_cycle.side_effect = check_kwargs
558367

559-
agent("test message", some_value="a_value")
560-
assert mock_event_loop_cycle.call_count == 1
368+
agent(
369+
"test message",
370+
some_value="a_value",
371+
system_prompt=override_system_prompt,
372+
model=override_model,
373+
tool_execution_handler=override_tool_execution_handler,
374+
event_loop_metrics=override_event_loop_metrics,
375+
callback_handler=override_callback_handler,
376+
tool_handler=override_tool_handler,
377+
messages=override_messages,
378+
tool_config=override_tool_config,
379+
)
380+
381+
mock_event_loop_cycle.assert_called_once()
561382

562383

563384
def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):
@@ -1186,3 +1007,6 @@ def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_
11861007
kwargs = mock_event_loop_cycle.call_args[1]
11871008
assert "event_loop_parent_span" in kwargs
11881009
assert kwargs["event_loop_parent_span"] == mock_span
1010+
1011+
1012+

0 commit comments

Comments
 (0)