Skip to content

feat: add class based tool loader to enable easier integation with existing codebases #162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
51be352
ci: add workflows verifying compatibility with tools, docs, and agent…
dbschmigelski May 19, 2025
d4eea13
Update verify-agent-builder-compatibility.yml to only run against 3.10
dbschmigelski May 19, 2025
3d2d5ff
Update verify-docs-compatibility.yml to only run against 3.10
dbschmigelski May 19, 2025
d1e3a7b
Update verify-tools-compatibility.yml to only run against 3.10
dbschmigelski May 19, 2025
552bd5f
Merge branch 'strands-agents:main' into dea/class-based-loader
dbschmigelski May 31, 2025
789530e
feat: add class based tool loader
XTEKiyZRLr8UsiQN May 31, 2025
d4023ad
rename test class and remove print statement
dbschmigelski Jun 2, 2025
48ef0f2
remove workflows from previous branch
dbschmigelski Jun 2, 2025
ecfd2e6
lint
dbschmigelski Jun 2, 2025
d096a88
fix remove duplicate .lower()
dbschmigelski Jun 2, 2025
0a1fa3c
remove duplicate detection as it is not needed
dbschmigelski Jun 2, 2025
c6ae2d0
refactor tool_name variable setting
dbschmigelski Jun 2, 2025
dbb4eca
edit doc string
dbschmigelski Jun 2, 2025
bb95920
remove unused import
dbschmigelski Jun 2, 2025
9056618
Merge branch 'strands-agents:main' into dea/class-based-loader
dbschmigelski Jun 12, 2025
4f6fbe9
fix: strongly type load_tools_from_instance to accept instance as an …
dbschmigelski Jun 12, 2025
4a0abc9
linting
dbschmigelski Jun 12, 2025
c2827d2
strongly type decroator extract_metadata
dbschmigelski Jun 12, 2025
cb02462
add type to class_loader methods variable
dbschmigelski Jun 12, 2025
57588b6
more linting
dbschmigelski Jun 12, 2025
46adafe
explicitly add func as a variable defined as Any
dbschmigelski Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
if results_truncated:
logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results)
return

# Try to trim index id when tool result cannot be truncated anymore
# If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size
Expand Down
2 changes: 2 additions & 0 deletions src/strands/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This module provides the core functionality for creating, managing, and executing tools through agents.
"""

from .class_loader import load_tools_from_instance
from .decorator import tool
from .thread_pool_executor import ThreadPoolExecutorWrapper
from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec
Expand All @@ -15,4 +16,5 @@
"normalize_schema",
"normalize_tool_spec",
"ThreadPoolExecutorWrapper",
"load_tools_from_instance",
]
144 changes: 144 additions & 0 deletions src/strands/tools/class_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""This module defines a method for accessing tools from an instance.

It exposes:
- `load_tools_from_instance`: loads all public methods from an instance as AgentTool objects, with automatic name
disambiguation for instance methods.

It will load instance, class, and static methods from the class, including inherited methods.

By default, all public methods (not starting with _) will be loaded as AgentTool objects, even if not decorated.

Note:
Tool names must be unique within an agent. If you load tools from multiple instances of the same class,
you MUST provide a unique label for each instance, or tools will overwrite each other in the registry.
The registry does not warn or error on duplicates; the last tool registered with a given name wins.

The `load_tools_from_instance` function will return a list of `AgentTool` objects.
"""

import inspect
import logging
from typing import Any, Callable, List, Optional

from ..types.tools import AgentTool, ToolResult, ToolSpec, ToolUse
from .decorator import FunctionToolMetadata

logger = logging.getLogger(__name__)


class GenericFunctionTool(AgentTool):
"""Wraps any callable (instance, static, or class method) as an AgentTool.

Uses FunctionToolMetadata for metadata extraction and input validation.
"""

def __init__(self, func: Callable, name: Optional[str] = None, description: Optional[str] = None):
"""Initialize a GenericFunctionTool."""
super().__init__()
self._func = func
try:
self._meta = FunctionToolMetadata(func)
self._tool_spec = self._meta.extract_metadata()
if name:
self._tool_spec["name"] = name
if description:
self._tool_spec["description"] = description
except Exception as e:
logger.warning("Could not convert %s to AgentTool: %s", getattr(func, "__name__", str(func)), str(e))
raise

@property
def tool_name(self) -> str:
"""Return the tool's name."""
return str(self._tool_spec["name"])

@property
def tool_spec(self) -> ToolSpec:
"""Return the tool's specification."""
return self._tool_spec

@property
def tool_type(self) -> str:
"""Return the tool's type."""
return "function"

def invoke(self, tool: ToolUse, *args: Any, **kwargs: Any) -> ToolResult:
"""Invoke the tool with validated input."""
try:
validated_input = self._meta.validate_input(tool["input"])
result = self._func(**validated_input)
return {
"toolUseId": tool.get("toolUseId", "unknown"),
"status": "success",
"content": [{"text": str(result)}],
}
except Exception as e:
return {
"toolUseId": tool.get("toolUseId", "unknown"),
"status": "error",
"content": [{"text": f"Error: {e}"}],
}


def load_tools_from_instance(
instance: object,
disambiguator: Optional[str] = None,
) -> List[AgentTool]:
"""Load all public methods from an instance as AgentTool objects with name disambiguation.

Instance methods are bound to the given instance and are disambiguated by suffixing the tool name
with the given label (or the instance id if no prefix is provided). Static and class methods are
not disambiguated, as they do not depend on instance state.

Args:
instance: The instance to inspect.
disambiguator: Optional string to disambiguate instance method tool names. If not provided, uses id(instance).

Returns:
List of AgentTool objects (GenericFunctionTool wrappers).

Note:
Tool names must be unique within an agent. If you load tools from multiple instances of the same
class, you MUST provide a unique label for each instance, or tools will overwrite each
other in the registry. The registry does not warn or error on duplicates; the last tool registered
with a given name wins. This function will log a warning if a duplicate tool name is detected in
the returned list.

Example:
from strands.tools.class_loader import load_tools_from_instance

class MyClass:
def foo(self, x: int) -> int:
return x + 1

@staticmethod
def bar(y: int) -> int:
return y * 2

instance = MyClass()
tools = load_tools_from_instance(instance, disambiguator="special")
# tools is a list of AgentTool objects for foo and bar, with foo disambiguated as 'myclass_foo_special'
"""
methods: List[AgentTool] = []
class_name = instance.__class__.__name__.lower()
func: Any
for name, _member in inspect.getmembers(instance.__class__):
if name.startswith("_"):
continue
tool_name = f"{class_name}_{name}"
raw_attr = instance.__class__.__dict__.get(name, None)
if isinstance(raw_attr, staticmethod):
func = raw_attr.__func__
elif isinstance(raw_attr, classmethod):
func = raw_attr.__func__.__get__(instance, instance.__class__)
else:
# Instance method: bind to instance and disambiguate
func = getattr(instance, name, None)
tool_name += f"_{str(id(instance))}" if disambiguator is None else f"_{disambiguator}"
if callable(func):
try:
methods.append(GenericFunctionTool(func, name=tool_name))
except Exception:
# Warning already logged in GenericFunctionTool
pass
return methods
8 changes: 5 additions & 3 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def my_tool(param1: str, param2: int = 42) -> dict:
import docstring_parser
from pydantic import BaseModel, Field, create_model

from ..types.tools import ToolSpec

# Type for wrapped function
T = TypeVar("T", bound=Callable[..., Any])

Expand Down Expand Up @@ -124,7 +126,7 @@ def _create_input_model(self) -> Type[BaseModel]:
# Handle case with no parameters
return create_model(model_name)

def extract_metadata(self) -> Dict[str, Any]:
def extract_metadata(self) -> ToolSpec:
"""Extract metadata from the function to create a tool specification.

This method analyzes the function to create a standardized tool specification that Strands Agent can use to
Expand Down Expand Up @@ -155,7 +157,7 @@ def extract_metadata(self) -> Dict[str, Any]:
self._clean_pydantic_schema(input_schema)

# Create tool specification
tool_spec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}}
tool_spec: ToolSpec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}}

return tool_spec

Expand Down Expand Up @@ -288,7 +290,7 @@ def decorator(f: T) -> T:
tool_spec = tool_meta.extract_metadata()

# Update with any additional kwargs
tool_spec.update(tool_kwargs)
tool_spec.update(tool_kwargs) # type: ignore

# Attach TOOL_SPEC directly to the original function (critical for backward compatibility)
f.TOOL_SPEC = tool_spec # type: ignore
Expand Down
26 changes: 26 additions & 0 deletions tests-integ/test_class_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from strands.agent.agent import Agent
from strands.tools.class_loader import load_tools_from_instance


class WeatherTimeTool:
def get_weather_in_paris(self) -> str:
return "sunny"

@staticmethod
def get_time_in_paris(r) -> str:
return "15:00"


def test_agent_weather_and_time():
tool = WeatherTimeTool()
tools = load_tools_from_instance(tool)
prompt = (
"What is the time and weather in paris?"
"return only with the weather and time for example 'rainy 04:00'"
"if you cannot respond with 'FAILED'"
)
agent = Agent(tools=tools)
response = agent(prompt)
text = str(response).lower()
assert "sunny" in text
assert "15:00" in text
4 changes: 2 additions & 2 deletions tests-integ/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def test_can_reuse_mcp_client():


@pytest.mark.skipif(
condition=os.environ.get("GITHUB_ACTIONS") == 'true',
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue"
condition=os.environ.get("GITHUB_ACTIONS") == "true",
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
)
def test_streamable_http_mcp_client():
server_thread = threading.Thread(
Expand Down
131 changes: 131 additions & 0 deletions tests/strands/tools/test_class_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import pytest

from strands.agent.agent import Agent
from strands.tools.class_loader import load_tools_from_instance


class MyTestClass:
def foo(self, x: int) -> int:
"""Add 1 to x."""
return x + 1

@staticmethod
def bar(y: int) -> int:
"""Multiply y by 2."""
return y * 2

@classmethod
def baz(cls, z: int) -> int:
"""Subtract 1 from z."""
return z - 1

not_a_method = 42


def test_agent_tool_invocation_for_all_method_types():
"""Test that agent.tool.{tool_name} works for instance, static, and class methods."""
instance = MyTestClass()
tools = load_tools_from_instance(instance, disambiguator="agenttest")
agent = Agent(tools=tools)
# Instance method
result_foo = agent.tool.mytestclass_foo_agenttest(x=5)
assert result_foo["status"] == "success"
assert result_foo["content"][0]["text"] == "6"
# Static method
result_bar = agent.tool.mytestclass_bar(y=3)
assert result_bar["status"] == "success"
assert result_bar["content"][0]["text"] == "6"
# Class method
result_baz = agent.tool.mytestclass_baz(z=10)
assert result_baz["status"] == "success"
assert result_baz["content"][0]["text"] == "9"


def test_non_callable_attributes_are_skipped():
"""Test that non-callable attributes are not loaded as tools."""

class ClassWithNonCallable:
foo = 123

def bar(self):
return 1

instance = ClassWithNonCallable()
tools = load_tools_from_instance(instance, disambiguator="nc")
tool_names = {tool.tool_name for tool in tools}
assert "classwithnoncallable_foo" not in tool_names
assert "classwithnoncallable_bar_nc" in tool_names


def test_error_handling_for_unconvertible_methods(monkeypatch):
"""Test that a warning is logged and method is skipped if it cannot be converted."""

class BadClass:
def bad(self, x):
return x

instance = BadClass()
# Patch FunctionToolMetadata to raise Exception
from strands.tools import class_loader

orig = class_loader.FunctionToolMetadata.__init__

def fail_init(self, func):
raise ValueError("fail")

monkeypatch.setattr(class_loader.FunctionToolMetadata, "__init__", fail_init)
with pytest.raises(ValueError):
# Direct instantiation should raise
class_loader.GenericFunctionTool(instance.bad)
# But loader should skip and not raise
tools = load_tools_from_instance(instance, disambiguator="bad")
assert tools == []
# Restore
monkeypatch.setattr(class_loader.FunctionToolMetadata, "__init__", orig)


def test_default_prefix_is_instance_id():
"""Test that the default prefix is id(instance) when no prefix is provided."""
instance = MyTestClass()
tools = load_tools_from_instance(instance)
tool_names = {tool.tool_name for tool in tools}
assert f"mytestclass_foo_{str(id(instance))}" in tool_names
assert "mytestclass_bar" in tool_names
assert "mytestclass_baz" in tool_names


def test_multiple_instances_of_same_class():
"""Test loading tools from multiple instances of the same class, including a static method."""

class Counter:
def __init__(self, start):
self.start = start

def increment(self, x: int) -> int:
return self.start + x

@staticmethod
def double_static(y: int) -> int:
return y * 2

a = Counter(10)
b = Counter(100)
tools_a = load_tools_from_instance(a, disambiguator="a")
tools_b = load_tools_from_instance(b, disambiguator="b")
agent = Agent(tools=tools_a + tools_b)
# Call increment for each instance
result_a = agent.tool.counter_increment_a(x=5)
result_b = agent.tool.counter_increment_b(x=5)
assert result_a["status"] == "success"
assert result_b["status"] == "success"
assert result_a["content"][0]["text"] == "15"
assert result_b["content"][0]["text"] == "105"
# Static method should be available (not disambiguated)
result_static = agent.tool.counter_double_static(y=7)
assert result_static["status"] == "success"
assert result_static["content"][0]["text"] == "14"
# Tool names are unique for instance methods, static method is shared
tool_names = set(agent.tool_names)
assert "counter_increment_a" in tool_names
assert "counter_increment_b" in tool_names
assert "counter_double_static" in tool_names
Loading