Skip to content

Commit aaf9715

Browse files
fix(tools): avoid KeyError in direct tool calls with ToolContext (#1213)
--------- Co-authored-by: Dean Schmigelski <dbschmigelski+github@gmail.com>
1 parent eaa6efb commit aaf9715

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

src/strands/tools/executors/_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ async def _stream(
7575

7676
invocation_state.update(
7777
{
78+
"agent": agent,
7879
"model": agent.model,
7980
"messages": agent.messages,
8081
"system_prompt": agent.system_prompt,

tests/strands/tools/executors/test_executor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,21 @@ async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_resul
459459
tru_results = tool_results
460460
exp_results = [exp_events[-1].tool_result]
461461
assert tru_results == exp_results
462+
463+
464+
@pytest.mark.asyncio
465+
async def test_executor_stream_updates_invocation_state_with_agent(
466+
executor, agent, tool_results, invocation_state, weather_tool, alist
467+
):
468+
"""Test that invocation_state is updated with agent reference."""
469+
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
470+
471+
# Start with empty invocation_state to verify agent is added
472+
empty_invocation_state = {}
473+
474+
stream = executor._stream(agent, tool_use, tool_results, empty_invocation_state)
475+
await alist(stream)
476+
477+
# Verify that the invocation_state was updated with the agent
478+
assert "agent" in empty_invocation_state
479+
assert empty_invocation_state["agent"] is agent

tests_integ/test_tool_context_injection.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,22 @@ def test_strands_context_integration_context_custom():
5454
agent("using a tool, write a bad story")
5555

5656
_validate_tool_result_content(agent)
57+
58+
59+
@tool(context=True)
60+
def calculate_sum(a: int, b: int, tool_context: ToolContext) -> int:
61+
result = a + b
62+
tool_context.agent.state.set("last_calculation", result)
63+
return result
64+
65+
66+
def test_agent_state_access_through_tool_context():
67+
"""Test that tools can access agent state through ToolContext."""
68+
agent = Agent(tools=[calculate_sum])
69+
result = agent.tool.calculate_sum(a=1, b=1)
70+
71+
# Verify the tool executed successfully
72+
assert result["status"] == "success"
73+
74+
# Verify the agent state was updated
75+
assert agent.state.get("last_calculation") == 2

0 commit comments

Comments
 (0)