Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def _filter_tools_with_request(self, request: AgentTurnCreateRequest) -> N
)
message = "\n".join([message.content for message in request.messages])
tools = [
dict(tool_name=tool.name, description=tool.description)
dict(tool_name=tool.tool_name, description=tool.description)
for tool in self.tool_defs
]
tools_filter_model_id = (
Expand Down Expand Up @@ -176,8 +176,8 @@ async def _filter_tools_with_request(self, request: AgentTurnCreateRequest) -> N
original_tools_count = len(self.tool_defs)
self.tool_defs = list(
filter(
lambda tool: tool.name in filtered_tools_names
or tool.name in always_included_tools,
lambda tool: tool.tool_name in filtered_tools_names
or tool.tool_name in always_included_tools,
self.tool_defs,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
)
from llama_stack.apis.agents import AgentConfig, AgentTurnCreateRequest, UserMessage
from llama_stack.apis.inference import ChatCompletionResponse, CompletionMessage
from llama_stack.apis.tools import ToolDef

from llama_stack.models.llama.datatypes import ToolDefinition

@pytest.fixture
def mock_inference_api(mocker):
Expand Down Expand Up @@ -54,7 +53,7 @@ def lightspeed_chat_agent(mock_inference_api, mock_storage, mocker):
async def test_create_and_execute_turn_with_filtering(lightspeed_chat_agent, mocker):
"""Test that create_and_execute_turn calls _filter_tools_with_request when enabled."""
lightspeed_chat_agent.tool_defs = [
ToolDef(name="test_tool", description="A test tool")
ToolDefinition(tool_name="test_tool", description="A test tool")
]
request = AgentTurnCreateRequest(
agent_id="test_agent",
Expand All @@ -74,7 +73,7 @@ async def test_create_and_execute_turn_without_filtering(lightspeed_chat_agent,
"""Test that create_and_execute_turn does not call _filter_tools_with_request when disabled."""
lightspeed_chat_agent.tools_filter_config.enabled = False
lightspeed_chat_agent.tool_defs = [
ToolDef(name="test_tool", description="A test tool")
ToolDefinition(tool_name="test_tool", description="A test tool")
]
request = AgentTurnCreateRequest(
agent_id="test_agent",
Expand All @@ -95,8 +94,8 @@ async def test_filter_tools_with_request(
):
"""Test the _filter_tools_with_request method."""
lightspeed_chat_agent.tool_defs = [
ToolDef(name="tool1", description="Tool 1"),
ToolDef(name="tool2", description="Tool 2"),
ToolDefinition(tool_name="tool1", description="Tool 1"),
ToolDefinition(tool_name="tool2", description="Tool 2"),
]
lightspeed_chat_agent.tool_name_to_args = {"tool1": {}, "tool2": {}}
mock_storage.get_session_turns.return_value = []
Expand All @@ -114,6 +113,6 @@ async def test_filter_tools_with_request(
await lightspeed_chat_agent._filter_tools_with_request(request)

assert len(lightspeed_chat_agent.tool_defs) == 1
assert lightspeed_chat_agent.tool_defs[0].name == "tool1"
assert lightspeed_chat_agent.tool_defs[0].tool_name == "tool1"
assert "tool1" in lightspeed_chat_agent.tool_name_to_args
assert "tool2" not in lightspeed_chat_agent.tool_name_to_args
Loading