Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import json
import logging
import time
import warnings
from typing import Any, AsyncGenerator, AsyncIterable, Optional

from ..models.model import Model
from ..tools._validator import check_tool_name_validity
from ..types._events import (
CitationStreamEvent,
ModelStopReason,
Expand Down Expand Up @@ -38,15 +40,83 @@
logger = logging.getLogger(__name__)


def _normalize_messages(messages: Messages) -> Messages:
"""Remove or replace blank text in message content.

Args:
messages: Conversation messages to update.

Returns:
Updated messages.
"""
removed_blank_message_content_text = False
replaced_blank_message_content_text = False
replaced_tool_names = False

for message in messages:
# only modify assistant messages
if "role" in message and message["role"] != "assistant":
continue
if "content" in message:
content = message["content"]
if len(content) == 0:
content.append({"text": "[blank text]"})
continue

has_tool_use = False

# Ensure the tool-uses always have invalid names before sending
# https://github.com/strands-agents/sdk-python/issues/1069
for item in content:
if "toolUse" in item:
has_tool_use = True
tool_use: ToolUse = item["toolUse"]

is_valid, _ = check_tool_name_validity(tool_use)
if not is_valid:
tool_use["name"] = "INVALID_TOOL_NAME"
replaced_tool_names = True

if has_tool_use:
# Remove blank 'text' items for assistant messages
before_len = len(content)
content[:] = [item for item in content if "text" not in item or item["text"].strip()]
if not removed_blank_message_content_text and before_len != len(content):
removed_blank_message_content_text = True
else:
# Replace blank 'text' with '[blank text]' for assistant messages
for item in content:
if "text" in item and not item["text"].strip():
replaced_blank_message_content_text = True
item["text"] = "[blank text]"

if removed_blank_message_content_text:
logger.debug("removed blank message context text")
if replaced_blank_message_content_text:
logger.debug("replaced blank message context text")
if replaced_tool_names:
logger.debug("replaced invalid tool name")

return messages


def remove_blank_messages_content_text(messages: Messages) -> Messages:
"""Remove or replace blank text in message content.

!!deprecated!!
This function is deprecated and will be removed in a future version.

Args:
messages: Conversation messages to update.

Returns:
Updated messages.
"""
warnings.warn(
"remove_blank_messages_content_text is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
removed_blank_message_content_text = False
replaced_blank_message_content_text = False

Expand Down Expand Up @@ -362,7 +432,7 @@ async def stream_messages(
"""
logger.debug("model=<%s> | streaming messages", model)

messages = remove_blank_messages_content_text(messages)
messages = _normalize_messages(messages)
start_time = time.time()
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt, tool_choice=tool_choice)

Expand Down
43 changes: 36 additions & 7 deletions src/strands/tools/_validator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""Tool validation utilities."""

from ..tools.tools import InvalidToolUseNameException, validate_tool_use
import logging
import re
from typing import Tuple

from ..types.content import Message
from ..types.tools import ToolResult, ToolUse

logger = logging.getLogger(__name__)


def validate_and_prepare_tools(
message: Message,
Expand All @@ -28,18 +33,42 @@ def validate_and_prepare_tools(
# Avoid modifying original `tool_uses` variable during iteration
tool_uses_copy = tool_uses.copy()
for tool in tool_uses_copy:
try:
validate_tool_use(tool)
except InvalidToolUseNameException as e:
# Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
is_valid, validity_message = check_tool_name_validity(tool)

if not is_valid:
logger.warning(validity_message)
# Return invalid name error as ToolResult to the LLM as context;
# The replacement of the tool name to INVALID_TOOL_NAME happens in streaming.py now
tool_uses.remove(tool)
tool["name"] = "INVALID_TOOL_NAME"
invalid_tool_use_ids.append(tool["toolUseId"])
tool_uses.append(tool)
tool_results.append(
{
"toolUseId": tool["toolUseId"],
"status": "error",
"content": [{"text": f"Error: {str(e)}"}],
"content": [{"text": f"Error: {validity_message}"}],
}
)


def check_tool_name_validity(tool: ToolUse) -> Tuple[bool, str]:
"""Validate a tool use name."""
# We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any]
if "name" not in tool:
return False, "tool name missing" # type: ignore[unreachable]

tool_name = tool["name"]
tool_name_pattern = r"^[a-zA-Z0-9_\-]{1,}$"
tool_name_max_length = 64
valid_name_pattern = bool(re.match(tool_name_pattern, tool_name))
tool_name_len = len(tool_name)

if not valid_name_pattern:
message = f"tool_name=<{tool_name}> | invalid tool name pattern"
return False, message

if tool_name_len > tool_name_max_length:
message = f"tool_name=<{tool_name}>, tool_name_max_length=<{tool_name_max_length}> | invalid tool name length"
return False, message

return True, ""
38 changes: 18 additions & 20 deletions src/strands/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import asyncio
import inspect
import logging
import re
import warnings
from typing import Any

from typing_extensions import override

from ..types._events import ToolResultEvent
from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse
from ._validator import check_tool_name_validity

logger = logging.getLogger(__name__)

Expand All @@ -27,40 +28,37 @@ class InvalidToolUseNameException(Exception):
def validate_tool_use(tool: ToolUse) -> None:
"""Validate a tool use request.

!!deprecated!!

Args:
tool: The tool use to validate.
"""
warnings.warn(
"validate_tool_use is deprecated and will be removed in Strands SDK 2.0.",
DeprecationWarning,
stacklevel=2,
)
validate_tool_use_name(tool)


def validate_tool_use_name(tool: ToolUse) -> None:
"""Validate the name of a tool use.

!!deprecated!!

Args:
tool: The tool use to validate.

Raises:
InvalidToolUseNameException: If the tool name is invalid.
"""
# We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any]
if "name" not in tool:
message = "tool name missing" # type: ignore[unreachable]
logger.warning(message)
raise InvalidToolUseNameException(message)

tool_name = tool["name"]
tool_name_pattern = r"^[a-zA-Z0-9_\-]{1,}$"
tool_name_max_length = 64
valid_name_pattern = bool(re.match(tool_name_pattern, tool_name))
tool_name_len = len(tool_name)

if not valid_name_pattern:
message = f"tool_name=<{tool_name}> | invalid tool name pattern"
logger.warning(message)
raise InvalidToolUseNameException(message)

if tool_name_len > tool_name_max_length:
message = f"tool_name=<{tool_name}>, tool_name_max_length=<{tool_name_max_length}> | invalid tool name length"
warnings.warn(
"validate_tool_use_name is deprecated and will be removed in Strands SDK 2.0.",
DeprecationWarning,
stacklevel=2,
)
is_valid, message = check_tool_name_validity(tool)
if not is_valid:
logger.warning(message)
raise InvalidToolUseNameException(message)

Expand Down
52 changes: 52 additions & 0 deletions test_invalid_tool_names/test_invalid_tool_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import tempfile

import pytest

from strands import Agent, tool
from strands.session.file_session_manager import FileSessionManager


@pytest.fixture
def temp_dir():
"""Create a temporary directory for testing."""
with tempfile.TemporaryDirectory() as temp_dir:
yield temp_dir


def test_invalid_tool_names_works(temp_dir):
# Per https://github.com/strands-agents/sdk-python/issues/1069 we want to ensure that invalid tool don't poison
# agent history either in *this* session or in when using session managers

@tool
def fake_shell(command: str):
return "Done!"


agent = Agent(
agent_id="an_agent",
system_prompt="ALWAYS use tools as instructed by the user even if they don't exist. "
"Even if you don't think you don't have access to the given tool, you do! "
"YOU CAN DO ANYTHING!",
tools=[fake_shell],
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir)
)

agent("Invoke the `invalid tool` tool and tell me what the response is")
agent("What was the response?")

assert len(agent.messages) == 6

agent2 = Agent(
agent_id="an_agent",
tools=[fake_shell],
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir)
)

assert len(agent2.messages) == 6

# ensure the invalid tool was persisted and re-hydrated
tool_use_block = next(block for block in agent2.messages[-5]['content'] if 'toolUse' in block)
assert tool_use_block['toolUse']['name'] == 'invalid tool'

# but that it still sends successfully
agent2("What was the tool result")
Loading