|
3 | 3 | import os
|
4 | 4 | import textwrap
|
5 | 5 | import threading
|
6 |
| -import time |
7 | 6 | import unittest.mock
|
8 | 7 | from time import sleep
|
9 | 8 |
|
10 | 9 | import pytest
|
11 | 10 |
|
12 | 11 | import strands
|
| 12 | +import strands.tools |
13 | 13 | from strands.agent.agent import Agent
|
14 | 14 | from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
|
15 | 15 | from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
|
16 | 16 | from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
|
17 | 17 | from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
|
18 | 18 | from strands.types.content import Messages
|
19 | 19 | from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
|
20 |
| -from strands.types.models import Model |
21 | 20 |
|
22 | 21 |
|
23 | 22 | @pytest.fixture
|
@@ -163,214 +162,6 @@ def agent(
|
163 | 162 | return agent
|
164 | 163 |
|
165 | 164 |
|
166 |
| -def test_agent_system_prompt_overrides_all_cases(): |
167 |
| - """Test all system prompt override scenarios and all 8 overrideable parameters. |
168 |
| - |
169 |
| - This comprehensive test ensures that: |
170 |
| - 1. System prompt overrides work in all scenarios |
171 |
| - 2. All 8 parameters that can be overridden in _execute_event_loop_cycle are properly handled |
172 |
| - 3. Prevents future regressions for all override functionality |
173 |
| - """ |
174 |
| - # Enhanced mock model that tracks all calls and parameters |
175 |
| - class ComprehensiveMockModel(Model): |
176 |
| - def __init__(self, model_id="mock-model"): |
177 |
| - self.model_id = model_id |
178 |
| - self.captured_system_prompts = [] |
179 |
| - self.captured_calls = [] |
180 |
| - |
181 |
| - def update_config(self, **model_config): |
182 |
| - pass |
183 |
| - |
184 |
| - def get_config(self): |
185 |
| - return {"model_id": self.model_id} |
186 |
| - |
187 |
| - def format_request(self, messages, tool_specs=None, system_prompt=None): |
188 |
| - self.captured_system_prompts.append(system_prompt) |
189 |
| - return {"messages": messages, "tool_specs": tool_specs, "system_prompt": system_prompt} |
190 |
| - |
191 |
| - def format_chunk(self, event): |
192 |
| - return {"messageStart": {"role": "assistant"}} |
193 |
| - |
194 |
| - def stream(self, request): |
195 |
| - yield {"contentBlockDelta": {"delta": {"text": "Mock response"}}} |
196 |
| - yield {"contentBlockStop": {}} |
197 |
| - yield {"messageStop": {"stopReason": "end_turn"}} |
198 |
| - |
199 |
| - def converse(self, messages, tool_specs=None, system_prompt=None): |
200 |
| - # Call format_request to capture system prompts like the base class does |
201 |
| - self.format_request(messages, tool_specs, system_prompt) |
202 |
| - |
203 |
| - self.captured_calls.append({ |
204 |
| - 'system_prompt': system_prompt, |
205 |
| - 'messages': messages, |
206 |
| - 'tool_specs': tool_specs, |
207 |
| - 'kwargs': {} |
208 |
| - }) |
209 |
| - return [ |
210 |
| - {"contentBlockStart": {"start": {}}}, |
211 |
| - {"contentBlockDelta": {"delta": {"text": "Test response"}}}, |
212 |
| - {"contentBlockStop": {}}, |
213 |
| - {"messageStop": {"stopReason": "end_turn"}}, |
214 |
| - ] |
215 |
| - |
216 |
| - # Mock classes for complex dependencies |
217 |
| - class MockToolHandler: |
218 |
| - def __init__(self, name): |
219 |
| - self.name = name |
220 |
| - def get_tools(self): |
221 |
| - return [] |
222 |
| - |
223 |
| - class MockCallbackHandler: |
224 |
| - def __init__(self, name): |
225 |
| - self.name = name |
226 |
| - |
227 |
| - def __call__(self, **kwargs): |
228 |
| - # Mock callback handler that does nothing |
229 |
| - pass |
230 |
| - |
231 |
| - class MockTrace: |
232 |
| - def __init__(self, name): |
233 |
| - self.name = name |
234 |
| - self.id = "mock-trace-id" |
235 |
| - def add_child(self, child): |
236 |
| - pass |
237 |
| - |
238 |
| - class MockMetrics: |
239 |
| - def __init__(self, name): |
240 |
| - self.name = name |
241 |
| - self.cycle_count = 0 |
242 |
| - self.cycle_durations = [] |
243 |
| - self.traces = [] |
244 |
| - |
245 |
| - def start_cycle(self): |
246 |
| - self.cycle_count += 1 |
247 |
| - start_time = time.time() |
248 |
| - cycle_trace = MockTrace(f"Cycle {self.cycle_count}") |
249 |
| - self.traces.append(cycle_trace) |
250 |
| - return start_time, cycle_trace |
251 |
| - |
252 |
| - def end_cycle(self, start_time, cycle_trace): |
253 |
| - duration = time.time() - start_time |
254 |
| - self.cycle_durations.append(duration) |
255 |
| - |
256 |
| - def update_usage(self, usage): |
257 |
| - pass |
258 |
| - |
259 |
| - def update_metrics(self, metrics): |
260 |
| - pass |
261 |
| - |
262 |
| - class MockExecutor: |
263 |
| - def __init__(self, name): |
264 |
| - self.name = name |
265 |
| - |
266 |
| - # === PART 1: Test System Prompt Override Scenarios === |
267 |
| - mock_model = ComprehensiveMockModel("system-prompt-test") |
268 |
| - |
269 |
| - # 1. Uses default system prompt |
270 |
| - default_prompt = "You are a helpful assistant." |
271 |
| - agent = Agent(system_prompt=default_prompt, model=mock_model) |
272 |
| - agent("Hello") |
273 |
| - assert mock_model.captured_system_prompts[-1] == default_prompt |
274 |
| - |
275 |
| - # 2. Override system prompt per call |
276 |
| - override_prompt = "You are a pirate." |
277 |
| - agent("Hello", system_prompt=override_prompt) |
278 |
| - assert mock_model.captured_system_prompts[-1] == override_prompt |
279 |
| - |
280 |
| - # 3. Reverts to default after override |
281 |
| - agent("Hello again") |
282 |
| - assert mock_model.captured_system_prompts[-1] == default_prompt |
283 |
| - |
284 |
| - # 4. Multiple overrides |
285 |
| - agent("Hi", system_prompt="You are a poet.") |
286 |
| - assert mock_model.captured_system_prompts[-1] == "You are a poet." |
287 |
| - agent("Hi", system_prompt="You are a robot.") |
288 |
| - assert mock_model.captured_system_prompts[-1] == "You are a robot." |
289 |
| - agent("Hi") |
290 |
| - assert mock_model.captured_system_prompts[-1] == default_prompt |
291 |
| - |
292 |
| - # 5. Override with None |
293 |
| - agent("Test", system_prompt=None) |
294 |
| - assert mock_model.captured_system_prompts[-1] is None |
295 |
| - |
296 |
| - # 6. Override with empty string |
297 |
| - agent("Test", system_prompt="") |
298 |
| - assert mock_model.captured_system_prompts[-1] == "" |
299 |
| - |
300 |
| - # 7. No default system prompt |
301 |
| - agent2 = Agent(model=mock_model) # No default |
302 |
| - agent2("Hello") |
303 |
| - assert mock_model.captured_system_prompts[-1] is None |
304 |
| - agent2("Hello", system_prompt="You are helpful.") |
305 |
| - assert mock_model.captured_system_prompts[-1] == "You are helpful." |
306 |
| - |
307 |
| - # === PART 2: Test All 8 Overrideable Parameters === |
308 |
| - override_model = ComprehensiveMockModel("override-model") |
309 |
| - original_model = ComprehensiveMockModel("original-model") |
310 |
| - |
311 |
| - # Create agent with original model |
312 |
| - comprehensive_agent = Agent( |
313 |
| - model=original_model, |
314 |
| - system_prompt="Default system prompt" |
315 |
| - ) |
316 |
| - |
317 |
| - # Test all 8 overrideable parameters |
318 |
| - override_messages = [{"role": "user", "content": [{"text": "Override message"}]}] |
319 |
| - override_tool_handler = MockToolHandler("override") |
320 |
| - override_callback = MockCallbackHandler("override") |
321 |
| - override_metrics = MockMetrics("override") |
322 |
| - override_executor = MockExecutor("override") |
323 |
| - override_tool_config = {"temperature": 0.8} |
324 |
| - |
325 |
| - # Execute with all overrides |
326 |
| - comprehensive_agent( |
327 |
| - "Test comprehensive override", |
328 |
| - system_prompt="Override system prompt", |
329 |
| - model=override_model, |
330 |
| - tool_execution_handler=override_executor, |
331 |
| - event_loop_metrics=override_metrics, |
332 |
| - callback_handler=override_callback, |
333 |
| - tool_handler=override_tool_handler, |
334 |
| - messages=override_messages, |
335 |
| - tool_config=override_tool_config |
336 |
| - ) |
337 |
| - |
338 |
| - # Verify the overridden model was used |
339 |
| - assert len(override_model.captured_calls) == 1 |
340 |
| - call = override_model.captured_calls[0] |
341 |
| - |
342 |
| - # Verify overrides were applied |
343 |
| - assert call['system_prompt'] == "Override system prompt" |
344 |
| - assert call['messages'] == override_messages |
345 |
| - # Note: tool_config gets processed into tool_specs at the event loop level |
346 |
| - # The model's converse method receives tool_specs, not the raw tool_config |
347 |
| - assert call['tool_specs'] is None # No tools configured in this test |
348 |
| - |
349 |
| - # Verify original model was not called during override |
350 |
| - assert len(original_model.captured_calls) == 0 |
351 |
| - |
352 |
| - # Test partial overrides - only override some parameters |
353 |
| - mock_model.captured_calls.clear() |
354 |
| - agent( |
355 |
| - "Another test", |
356 |
| - system_prompt="Partial override", |
357 |
| - model=mock_model |
358 |
| - # Other parameters use defaults |
359 |
| - ) |
360 |
| - |
361 |
| - assert len(mock_model.captured_calls) == 1 |
362 |
| - partial_call = mock_model.captured_calls[0] |
363 |
| - assert partial_call['system_prompt'] == "Partial override" |
364 |
| - |
365 |
| - # Test no overrides - should use defaults |
366 |
| - original_model.captured_calls.clear() |
367 |
| - comprehensive_agent("Default test") |
368 |
| - |
369 |
| - assert len(original_model.captured_calls) == 1 |
370 |
| - default_call = original_model.captured_calls[0] |
371 |
| - assert default_call['system_prompt'] == "Default system prompt" |
372 |
| - |
373 |
| - |
374 | 165 | def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry):
|
375 | 166 | _ = tool_registry
|
376 | 167 |
|
@@ -547,17 +338,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler,
|
547 | 338 | ],
|
548 | 339 | ]
|
549 | 340 |
|
| 341 | + override_system_prompt = "Override system prompt" |
| 342 | + override_model = unittest.mock.Mock() |
| 343 | + override_tool_execution_handler = unittest.mock.Mock() |
| 344 | + override_event_loop_metrics = unittest.mock.Mock() |
| 345 | + override_callback_handler = unittest.mock.Mock() |
| 346 | + override_tool_handler = unittest.mock.Mock() |
| 347 | + override_messages = [{"role": "user", "content": [{"text": "override msg"}]}] |
| 348 | + override_tool_config = {"test": "config"} |
| 349 | + |
550 | 350 | def check_kwargs(some_value, **kwargs):
|
551 | 351 | assert some_value == "a_value"
|
552 | 352 | assert kwargs is not None
|
| 353 | + assert kwargs["system_prompt"] == override_system_prompt |
| 354 | + assert kwargs["model"] == override_model |
| 355 | + assert kwargs["tool_execution_handler"] == override_tool_execution_handler |
| 356 | + assert kwargs["event_loop_metrics"] == override_event_loop_metrics |
| 357 | + assert kwargs["callback_handler"] == override_callback_handler |
| 358 | + assert kwargs["tool_handler"] == override_tool_handler |
| 359 | + assert kwargs["messages"] == override_messages |
| 360 | + assert kwargs["tool_config"] == override_tool_config |
| 361 | + assert kwargs["agent"] == agent |
553 | 362 |
|
554 | 363 | # Return expected values from event_loop_cycle
|
555 | 364 | return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}
|
556 | 365 |
|
557 | 366 | mock_event_loop_cycle.side_effect = check_kwargs
|
558 | 367 |
|
559 |
| - agent("test message", some_value="a_value") |
560 |
| - assert mock_event_loop_cycle.call_count == 1 |
| 368 | + agent( |
| 369 | + "test message", |
| 370 | + some_value="a_value", |
| 371 | + system_prompt=override_system_prompt, |
| 372 | + model=override_model, |
| 373 | + tool_execution_handler=override_tool_execution_handler, |
| 374 | + event_loop_metrics=override_event_loop_metrics, |
| 375 | + callback_handler=override_callback_handler, |
| 376 | + tool_handler=override_tool_handler, |
| 377 | + messages=override_messages, |
| 378 | + tool_config=override_tool_config, |
| 379 | + ) |
| 380 | + |
| 381 | + mock_event_loop_cycle.assert_called_once() |
561 | 382 |
|
562 | 383 |
|
563 | 384 | def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):
|
@@ -1186,3 +1007,6 @@ def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_
|
1186 | 1007 | kwargs = mock_event_loop_cycle.call_args[1]
|
1187 | 1008 | assert "event_loop_parent_span" in kwargs
|
1188 | 1009 | assert kwargs["event_loop_parent_span"] == mock_span
|
| 1010 | + |
| 1011 | + |
| 1012 | + |
0 commit comments