Skip to content

Commit 1d0530f

Browse files
Shubhamraut01awsarron
authored andcommitted
feat: Add dynamic system prompt override functionality (strands-agents#108)
1 parent 91f3ecd commit 1d0530f

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

src/strands/agent/agent.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -457,27 +457,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
457457
Returns:
458458
The result of the event loop cycle.
459459
"""
460-
kwargs.pop("agent", None)
461-
kwargs.pop("model", None)
462-
kwargs.pop("system_prompt", None)
463-
kwargs.pop("tool_execution_handler", None)
464-
kwargs.pop("event_loop_metrics", None)
465-
kwargs.pop("callback_handler", None)
466-
kwargs.pop("tool_handler", None)
467-
kwargs.pop("messages", None)
468-
kwargs.pop("tool_config", None)
460+
# Extract parameters with fallbacks to instance values
461+
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
462+
model = kwargs.pop("model", self.model)
463+
tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper)
464+
event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics)
465+
callback_handler_override = kwargs.pop("callback_handler", callback_handler)
466+
tool_handler = kwargs.pop("tool_handler", self.tool_handler)
467+
messages = kwargs.pop("messages", self.messages)
468+
tool_config = kwargs.pop("tool_config", self.tool_config)
469+
kwargs.pop("agent", None) # Remove agent to avoid conflicts
469470

470471
try:
471472
# Execute the main event loop cycle
472473
stop_reason, message, metrics, state = event_loop_cycle(
473-
model=self.model,
474-
system_prompt=self.system_prompt,
475-
messages=self.messages, # will be modified by event_loop_cycle
476-
tool_config=self.tool_config,
477-
callback_handler=callback_handler,
478-
tool_handler=self.tool_handler,
479-
tool_execution_handler=self.thread_pool_wrapper,
480-
event_loop_metrics=self.event_loop_metrics,
474+
model=model,
475+
system_prompt=system_prompt,
476+
messages=messages, # will be modified by event_loop_cycle
477+
tool_config=tool_config,
478+
callback_handler=callback_handler_override,
479+
tool_handler=tool_handler,
480+
tool_execution_handler=tool_execution_handler,
481+
event_loop_metrics=event_loop_metrics,
481482
agent=self,
482483
event_loop_parent_span=self.trace_span,
483484
**kwargs,
@@ -488,8 +489,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
488489
except ContextWindowOverflowException as e:
489490
# Try reducing the context size and retrying
490491

491-
self.conversation_manager.reduce_context(self.messages, e=e)
492-
return self._execute_event_loop_cycle(callback_handler, kwargs)
492+
self.conversation_manager.reduce_context(messages, e=e)
493+
return self._execute_event_loop_cycle(callback_handler_override, kwargs)
493494

494495
def _record_tool_execution(
495496
self,

tests/strands/agent/test_agent.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,17 +337,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler,
337337
],
338338
]
339339

340+
override_system_prompt = "Override system prompt"
341+
override_model = unittest.mock.Mock()
342+
override_tool_execution_handler = unittest.mock.Mock()
343+
override_event_loop_metrics = unittest.mock.Mock()
344+
override_callback_handler = unittest.mock.Mock()
345+
override_tool_handler = unittest.mock.Mock()
346+
override_messages = [{"role": "user", "content": [{"text": "override msg"}]}]
347+
override_tool_config = {"test": "config"}
348+
340349
def check_kwargs(some_value, **kwargs):
341350
assert some_value == "a_value"
342351
assert kwargs is not None
352+
assert kwargs["system_prompt"] == override_system_prompt
353+
assert kwargs["model"] == override_model
354+
assert kwargs["tool_execution_handler"] == override_tool_execution_handler
355+
assert kwargs["event_loop_metrics"] == override_event_loop_metrics
356+
assert kwargs["callback_handler"] == override_callback_handler
357+
assert kwargs["tool_handler"] == override_tool_handler
358+
assert kwargs["messages"] == override_messages
359+
assert kwargs["tool_config"] == override_tool_config
360+
assert kwargs["agent"] == agent
343361

344362
# Return expected values from event_loop_cycle
345363
return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}
346364

347365
mock_event_loop_cycle.side_effect = check_kwargs
348366

349-
agent("test message", some_value="a_value")
350-
assert mock_event_loop_cycle.call_count == 1
367+
agent(
368+
"test message",
369+
some_value="a_value",
370+
system_prompt=override_system_prompt,
371+
model=override_model,
372+
tool_execution_handler=override_tool_execution_handler,
373+
event_loop_metrics=override_event_loop_metrics,
374+
callback_handler=override_callback_handler,
375+
tool_handler=override_tool_handler,
376+
messages=override_messages,
377+
tool_config=override_tool_config,
378+
)
379+
380+
mock_event_loop_cycle.assert_called_once()
351381

352382

353383
def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):

0 commit comments

Comments
 (0)