Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,5 @@ markers = [
"billable: marks test as billable (deselect with '-m \"not billable\"')",
]
addopts = "-m 'not billable' --ignore=src"
log_cli = true
log_cli_level = "DEBUG"
8 changes: 7 additions & 1 deletion src/rai/rai/agents/tool_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import json
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast

from langchain_core.messages import AIMessage, ToolCall, ToolMessage
Expand Down Expand Up @@ -65,9 +66,14 @@ def run_one(call: ToolCall):
artifact = None

try:
ts = time.perf_counter()
output = self.tools_by_name[call["name"]].invoke(call, config) # type: ignore
te = time.perf_counter() - ts
self.logger.info(
"Tool output (max 100 chars): " + str(output.content[0:100])
f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {str(output.content)[:100]}{'...' if len(str(output.content)) > 100 else ''}"
)
self.logger.debug(
f"Tool {call['name']} output: \n\n{str(output.content)}"
)
except ValidationError as e:
errors = e.errors()
Expand Down
82 changes: 82 additions & 0 deletions tests/core/test_tool_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langchain_core.tools import tool

from rai.agents.tool_runner import ToolRunner
from rai.messages import HumanMultimodalMessage, ToolMultimodalMessage
from rai.messages.utils import preprocess_image
from rai.tools.ros.debugging import ros2_topic


@tool(response_format="content_and_artifact")
def get_image():
"""Get an image from the camera"""
return "Image retrieved", {
"images": [preprocess_image("docs/imgs/o3deSimulation.png")]
}


def test_tool_runner_invalid_call():
runner = ToolRunner(tools=[ros2_topic], logger=logging.getLogger(__name__))
tool_call = ToolCall(name="bad_fn", args={"command": "list"}, id="12345")
state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]}
output = runner.invoke(state)
assert isinstance(
output["messages"][0], AIMessage
), "First message is not an AIMessage"
assert isinstance(
output["messages"][1], ToolMessage
), "Tool output is not a tool message"
assert output["messages"][1].status == "error"


def test_tool_runner():
runner = ToolRunner(tools=[ros2_topic], logger=logging.getLogger(__name__))

tool_call = ToolCall(name="ros2_topic", args={"command": "list"}, id="12345")
state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]}
output = runner.invoke(state)
assert isinstance(
output["messages"][0], AIMessage
), "First message is not an AIMessage"
assert isinstance(
output["messages"][1], ToolMessage
), "Tool output is not a tool message"
assert (
len(output["messages"][-1].content) > 0
), "Tool output is empty. At least rosout should be visible."


def test_tool_runner_multimodal():
runner = ToolRunner(
tools=[ros2_topic, get_image], logger=logging.getLogger(__name__)
)

tool_call = ToolCall(name="get_image", args={}, id="12345")
state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]}
output = runner.invoke(state)

assert isinstance(
output["messages"][0], AIMessage
), "First message is not an AIMessage"
assert isinstance(
output["messages"][1], ToolMultimodalMessage
), "Tool output is not a multimodal message"
assert isinstance(
output["messages"][2], HumanMultimodalMessage
), "Human output is not a multimodal message"