Skip to content

iterative tool handler process #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def caller(
}

# Execute the tool
tool_result = self._agent.tool_handler.process(
events = self._agent.tool_handler.process(
tool=tool_use,
model=self._agent.model,
system_prompt=self._agent.system_prompt,
Expand All @@ -137,6 +137,12 @@ def caller(
kwargs=kwargs,
)

try:
while True:
next(events)
except StopIteration as stop:
tool_result = cast(ToolResult, stop.value)

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
Expand Down
13 changes: 9 additions & 4 deletions src/strands/handlers/tool_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..tools.registry import ToolRegistry
from ..types.content import Messages
from ..types.models import Model
from ..types.tools import ToolConfig, ToolHandler, ToolUse
from ..types.tools import ToolConfig, ToolGenerator, ToolHandler, ToolUse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -35,7 +35,7 @@ def process(
messages: Messages,
tool_config: ToolConfig,
kwargs: dict[str, Any],
) -> Any:
) -> ToolGenerator:
"""Process a tool invocation.

Looks up the tool in the registry and invokes it with the provided parameters.
Expand All @@ -48,8 +48,11 @@ def process(
tool_config: Configuration for the tool.
kwargs: Additional keyword arguments passed to the tool.

Yields:
Events of the tool invocation.

Returns:
The result of the tool invocation, or an error response if the tool fails or is not found.
The final tool result or an error response if the tool fails or is not found.
"""
logger.debug("tool=<%s> | invoking", tool)
tool_use_id = tool["toolUseId"]
Expand Down Expand Up @@ -82,7 +85,9 @@ def process(
}
)

return tool_func.invoke(tool, **kwargs)
result = tool_func.invoke(tool, **kwargs)
yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from
return result

except Exception as e:
logger.exception("tool_name=<%s> | failed to process tool", tool_name)
Expand Down
17 changes: 8 additions & 9 deletions src/strands/tools/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from ..tools.tools import InvalidToolUseNameException, validate_tool_use
from ..types.content import Message
from ..types.event_loop import ParallelToolExecutorInterface
from ..types.tools import ToolResult, ToolUse
from ..types.tools import ToolGenerator, ToolResult, ToolUse

logger = logging.getLogger(__name__)


def run_tools(
handler: Callable[[ToolUse], ToolResult],
handler: Callable[[ToolUse], Generator[dict[str, Any], None, ToolResult]],
tool_uses: list[ToolUse],
event_loop_metrics: EventLoopMetrics,
invalid_tool_use_ids: list[str],
Expand All @@ -44,16 +44,15 @@ def run_tools(
Events of the tool invocations. Tool results are appended to `tool_results`.
"""

def handle(tool: ToolUse) -> Generator[dict[str, Any], None, ToolResult]:
def handle(tool: ToolUse) -> ToolGenerator:
tracer = get_tracer()
tool_call_span = tracer.start_tool_call_span(tool, parent_span)

tool_name = tool["name"]
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
tool_start_time = time.time()

result = handler(tool)
yield {"result": result} # Placeholder until handler becomes a generator from which we can yield from
result = yield from handler(tool)

tool_success = result.get("status") == "success"
tool_duration = time.time() - tool_start_time
Expand All @@ -74,14 +73,14 @@ def work(
) -> ToolResult:
events = handle(tool)

while True:
try:
try:
while True:
event = next(events)
worker_queue.put((worker_id, event))
worker_event.wait()

except StopIteration as stop:
return cast(ToolResult, stop.value)
except StopIteration as stop:
return cast(ToolResult, stop.value)

tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]

Expand Down
23 changes: 15 additions & 8 deletions src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Union

from typing_extensions import TypedDict

Expand Down Expand Up @@ -90,7 +90,7 @@ class ToolResult(TypedDict):
toolUseId: The unique identifier of the tool use request that produced this result.
"""

content: List[ToolResultContent]
content: list[ToolResultContent]
status: ToolResultStatus
toolUseId: str

Expand Down Expand Up @@ -122,9 +122,9 @@ class ToolChoiceTool(TypedDict):


ToolChoice = Union[
Dict[Literal["auto"], ToolChoiceAuto],
Dict[Literal["any"], ToolChoiceAny],
Dict[Literal["tool"], ToolChoiceTool],
dict[Literal["auto"], ToolChoiceAuto],
dict[Literal["any"], ToolChoiceAny],
dict[Literal["tool"], ToolChoiceTool],
]
"""
Configuration for how the model should choose tools.
Expand All @@ -135,6 +135,10 @@ class ToolChoiceTool(TypedDict):
"""


ToolGenerator = Generator[dict[str, Any], None, ToolResult]
"""Generator of tool events and a returned tool result."""


class ToolConfig(TypedDict):
"""Configuration for tools in a model request.

Expand All @@ -143,7 +147,7 @@ class ToolConfig(TypedDict):
toolChoice: Configuration for how the model should choose tools.
"""

tools: List[Tool]
tools: list[Tool]
toolChoice: ToolChoice


Expand Down Expand Up @@ -250,7 +254,7 @@ def process(
messages: "Messages",
tool_config: ToolConfig,
kwargs: dict[str, Any],
) -> ToolResult:
) -> ToolGenerator:
"""Process a tool use request and execute the tool.

Args:
Expand All @@ -261,7 +265,10 @@ def process(
tool_config: The tool configuration for the current session.
kwargs: Additional context-specific arguments.

Yields:
Events of the tool invocation.

Returns:
The result of the tool execution.
The final tool result.
"""
...
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,22 @@ def boto3_profile_path(boto3_profile, tmp_path, monkeypatch):
monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(path))

return path


## Itertools


@pytest.fixture(scope="session")
def generate():
def generate(generator):
events = []

try:
while True:
event = next(generator)
events.append(event)

except StopIteration as stop:
return events, stop.value

return generate
6 changes: 2 additions & 4 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ def test_agent_init_with_no_model_or_model_id():


def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint):
agent.tool_handler = unittest.mock.Mock()
agent.tool_handler = unittest.mock.Mock(process=unittest.mock.Mock(return_value=iter([])))

@strands.tools.tool(name="system_prompter")
def function(system_prompt: str) -> str:
Expand All @@ -798,7 +798,7 @@ def function(system_prompt: str) -> str:


def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint):
agent.tool_handler = unittest.mock.Mock()
agent.tool_handler = unittest.mock.Mock(process=unittest.mock.Mock(return_value=iter([])))

tool_name = "system-prompter"

Expand Down Expand Up @@ -826,8 +826,6 @@ def function(system_prompt: str) -> str:


def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint):
agent.tool_handler = unittest.mock.Mock()

mock_randint.return_value = 1

with pytest.raises(AttributeError) as err:
Expand Down
12 changes: 8 additions & 4 deletions tests/strands/handlers/test_tool_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,33 @@ def identity(a: int) -> int:
return {"toolUseId": "identity", "name": "identity", "input": {"a": 1}}


def test_process(tool_handler, tool_use_identity):
tru_result = tool_handler.process(
def test_process(tool_handler, tool_use_identity, generate):
process = tool_handler.process(
tool_use_identity,
model=unittest.mock.Mock(),
system_prompt="p1",
messages=[],
tool_config={},
kwargs={},
)

_, tru_result = generate(process)
exp_result = {"toolUseId": "identity", "status": "success", "content": [{"text": "1"}]}

assert tru_result == exp_result


def test_process_missing_tool(tool_handler):
tru_result = tool_handler.process(
def test_process_missing_tool(tool_handler, generate):
process = tool_handler.process(
tool={"toolUseId": "missing", "name": "missing", "input": {}},
model=unittest.mock.Mock(),
system_prompt="p1",
messages=[],
tool_config={},
kwargs={},
)

_, tru_result = generate(process)
exp_result = {
"toolUseId": "missing",
"status": "error",
Expand Down
7 changes: 5 additions & 2 deletions tests/strands/tools/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def moto_autouse(moto_env):
@pytest.fixture
def tool_handler(request):
def handler(tool_use):
yield {"event": "abc"}
return {
**params,
"toolUseId": tool_use["toolUseId"],
Expand Down Expand Up @@ -102,7 +103,9 @@ def test_run_tools(
cycle_trace,
parallel_tool_executor,
)
list(stream)

tru_events = list(stream)
exp_events = [{"event": "abc"}]

tru_results = tool_results
exp_results = [
Expand All @@ -117,7 +120,7 @@ def test_run_tools(
},
]

assert tru_results == exp_results
assert tru_events == exp_events and tru_results == exp_results


@pytest.mark.parametrize("invalid_tool_use_ids", [["t1"]], indirect=True)
Expand Down
Loading