Skip to content

Commit 7fce03e

Browse files
dbschmigelskipgrayy
authored andcommitted
security(tool_loader): prevent tool name and sys modules collisions in tool_loader (strands-agents#1214)
1 parent b357d0b commit 7fce03e

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

src/strands/tools/loader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
logger = logging.getLogger(__name__)
1919

20+
_TOOL_MODULE_PREFIX = "_strands_tool_"
21+
2022

2123
def load_tool_from_string(tool_string: str) -> List[AgentTool]:
2224
"""Load tools follows strands supported input string formats.
@@ -65,7 +67,7 @@ def load_tools_from_file_path(tool_path: str) -> List[AgentTool]:
6567

6668
module = importlib.util.module_from_spec(spec)
6769
# Load, or re-load, the module
68-
sys.modules[module_name] = module
70+
sys.modules[f"{_TOOL_MODULE_PREFIX}{module_name}"] = module
6971
# Execute the module to run any top level code
7072
spec.loader.exec_module(module)
7173

@@ -200,7 +202,7 @@ def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]:
200202
raise ImportError(f"No loader available for {tool_name}")
201203

202204
module = importlib.util.module_from_spec(spec)
203-
sys.modules[tool_name] = module
205+
sys.modules[f"{_TOOL_MODULE_PREFIX}{tool_name}"] = module
204206
spec.loader.exec_module(module)
205207

206208
# Collect function-based tools decorated with @tool

tests/strands/tools/test_loader.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import os
22
import re
3+
import sys
34
import tempfile
45
import textwrap
56

67
import pytest
78

89
from strands.tools.decorator import DecoratedFunctionTool
9-
from strands.tools.loader import ToolLoader, load_tools_from_file_path
10+
from strands.tools.loader import _TOOL_MODULE_PREFIX, ToolLoader, load_tools_from_file_path
1011
from strands.tools.tools import PythonAgentTool
1112

1213

@@ -317,3 +318,29 @@ def test_load_tools_from_file_path_module_spec_missing():
317318
with tempfile.NamedTemporaryFile() as f:
318319
with pytest.raises(ImportError, match=f"Could not create spec for {os.path.basename(f.name)}"):
319320
load_tools_from_file_path(f.name)
321+
322+
323+
def test_tool_module_prefix_prevents_collision():
324+
"""Test that tool modules are loaded with prefix to prevent sys.modules collisions."""
325+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
326+
f.write(
327+
textwrap.dedent("""
328+
import strands
329+
330+
@strands.tools.tool
331+
def test_tool():
332+
return "test"
333+
""")
334+
)
335+
f.flush()
336+
337+
# Load the tool
338+
tools = load_tools_from_file_path(f.name)
339+
340+
# Check that module is in sys.modules with prefix
341+
module_name = os.path.basename(f.name).split(".")[0]
342+
prefixed_name = f"{_TOOL_MODULE_PREFIX}{module_name}"
343+
344+
assert prefixed_name in sys.modules
345+
assert len(tools) == 1
346+
assert tools[0].tool_name == "test_tool"

0 commit comments

Comments
 (0)