Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,18 @@ class BeforeToolCallEvent(HookEvent):
to change which tool gets executed. This may be None if tool lookup failed.
tool_use: The tool parameters that will be passed to selected_tool.
invocation_state: Keyword arguments that will be passed to the tool.
cancel_tool: A user defined message that when set, will cancel the tool call.
The message will be placed into a tool result with an error status. Alternatively, customers can set the
field to `True` and Strands will populate a default cancel message.
"""

selected_tool: Optional[AgentTool]
tool_use: ToolUse
invocation_state: dict[str, Any]
cancel_tool: bool | str = False

def _can_write(self, name: str) -> bool:
return name in ["selected_tool", "tool_use"]
return name in ["cancel_tool", "selected_tool", "tool_use"]


@dataclass
Expand Down
28 changes: 27 additions & 1 deletion src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,32 @@ async def _stream(
)
)

if before_event.cancel_tool:
after_event = agent.hooks.invoke_callbacks(
AfterToolCallEvent(
agent=agent,
tool_use=tool_use,
invocation_state=invocation_state,
result={
"toolUseId": str(tool_use.get("toolUseId")),
"status": "error",
"content": [
{
"text": (
before_event.cancel_tool
if isinstance(before_event.cancel_tool, str)
else "tool cancelled by user"
),
},
],
},
selected_tool=None,
)
)
yield ToolResultEvent(after_event.result)
tool_results.append(after_event.result)
return

try:
selected_tool = before_event.selected_tool
tool_use = before_event.tool_use
Expand Down Expand Up @@ -123,7 +149,7 @@ async def _stream(
# so that we don't needlessly yield ToolStreamEvents for non-generator callbacks.
# In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent
# we yield it directly; all other cases (non-sdk AgentTools), we wrap events in
# ToolStreamEvent and the last even is just the result
# ToolStreamEvent and the last event is just the result.

if isinstance(event, ToolResultEvent):
# below the last "event" must point to the tool_result
Expand Down
38 changes: 38 additions & 0 deletions tests/strands/tools/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ def tracer():
yield mock_get_tracer.return_value


@pytest.fixture
def cancel_hook(request):
def callback(event):
event.cancel_tool = request.param
return event

return callback


@pytest.mark.asyncio
async def test_executor_stream_yields_result(
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist
Expand Down Expand Up @@ -215,3 +224,32 @@ async def test_executor_stream_with_trace(

cycle_trace.add_child.assert_called_once()
assert isinstance(cycle_trace.add_child.call_args[0][0], Trace)


@pytest.mark.parametrize(
("cancel_hook", "result_text"),
[(True, "tool cancelled by user"), ("user cancel message", "user cancel message")],
indirect=["cancel_hook"],
)
@pytest.mark.asyncio
async def test_executor_stream_cancel(cancel_hook, result_text, executor, agent, tool_results, invocation_state, alist):
agent.hooks.add_callback(BeforeToolCallEvent, cancel_hook)
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}

stream = executor._stream(agent, tool_use, tool_results, invocation_state)

tru_events = await alist(stream)
exp_events = [
ToolResultEvent(
{
"toolUseId": "1",
"status": "error",
"content": [{"text": result_text}],
},
),
]
assert tru_events == exp_events

tru_results = tool_results
exp_results = [exp_events[-1].tool_result]
assert tru_results == exp_results
15 changes: 15 additions & 0 deletions tests_integ/tools/executors/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from strands.hooks import BeforeToolCallEvent, HookProvider


@pytest.fixture
def cancel_hook():
class Hook(HookProvider):
def register_hooks(self, registry):
registry.add_callback(BeforeToolCallEvent, self.cancel)

def cancel(self, event):
event.cancel_tool = "cancelled tool call"

return Hook()
12 changes: 12 additions & 0 deletions tests_integ/tools/executors/test_concurrent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json

import pytest

Expand Down Expand Up @@ -59,3 +60,14 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
{"name": "time_tool", "event": "end"},
]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_agent_invoke_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])

await agent.invoke_async("What is the time in New York?")
messages = json.dumps(agent.messages)

assert len(tool_events) == 0
assert "cancelled tool call" in messages
12 changes: 12 additions & 0 deletions tests_integ/tools/executors/test_sequential.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json

import pytest

Expand Down Expand Up @@ -59,3 +60,14 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
{"name": "weather_tool", "event": "end"},
]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_agent_invoke_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])

await agent.invoke_async("What is the time in New York?")
messages = json.dumps(agent.messages)

assert len(tool_events) == 0
assert "cancelled tool call" in messages
Loading