Skip to content

Commit 22e9472

Browse files
authored
refactor: tools (#521)
1 parent 8821e30 commit 22e9472

File tree

31 files changed

+99
-727
lines changed

31 files changed

+99
-727
lines changed

examples/manipulation-demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from langchain_core.messages import HumanMessage
1919
from rai.agents.conversational_agent import create_conversational_agent
2020
from rai.communication.ros2.connectors import ROS2ARIConnector
21-
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
2221
from rai.tools.ros2 import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
22+
from rai.tools.ros2.manipulation import GetObjectPositionsTool, MoveToPointTool
2323
from rai.utils.model_initialization import get_llm_model
2424
from rai_open_set_vision.tools import GetGrabbingPointTool
2525

examples/rosbot-xl-demo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
from rai.agents import ReActAgent
2121
from rai.communication.ros2 import ROS2ARIConnector
2222
from rai.frontend.streamlit import run_streamlit_app
23-
from rai.tools.ros.manipulation import GetGrabbingPointTool, GetObjectPositionsTool
2423
from rai.tools.ros2 import (
24+
GetObjectPositionsTool,
2525
GetROS2ImageConfiguredTool,
2626
GetROS2TransformConfiguredTool,
2727
Nav2Toolkit,
2828
)
2929
from rai.tools.time import WaitForSecondsTool
3030
from rai.utils.model_initialization import get_llm_model
31+
from rai_open_set_vision.tools import GetGrabbingPointTool
3132

3233
# Set page configuration first
3334
st.set_page_config(

src/rai_bench/rai_bench/examples/o3de_test_benchmark.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,11 @@
2323
from langchain.tools import BaseTool
2424
from rai.agents.conversational_agent import create_conversational_agent
2525
from rai.communication.ros2.connectors import ROS2ARIConnector
26-
from rai.tools.ros.manipulation import (
27-
GetObjectPositionsTool,
28-
MoveToPointTool,
29-
)
3026
from rai.tools.ros2 import (
27+
GetObjectPositionsTool,
3128
GetROS2ImageTool,
3229
GetROS2TopicsNamesAndTypesTool,
30+
MoveToPointTool,
3331
)
3432
from rai.utils.model_initialization import get_llm_model
3533
from rai_open_set_vision.tools import GetGrabbingPointTool

src/rai_bench/rai_bench/tool_calling_agent_bench/mocked_tools.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,14 @@
2020
from rai.communication.ros2.connectors import ROS2ARIConnector
2121
from rai.communication.ros2.messages import ROS2ARIMessage
2222
from rai.messages import MultimodalArtifact, preprocess_image
23-
from rai.tools.ros.manipulation import (
24-
GetGrabbingPointTool,
25-
GetObjectPositionsTool,
26-
MoveToPointTool,
27-
)
2823
from rai.tools.ros2 import (
24+
GetObjectPositionsTool,
2925
GetROS2ImageTool,
3026
GetROS2TopicsNamesAndTypesTool,
27+
MoveToPointTool,
3128
ReceiveROS2MessageTool,
3229
)
30+
from rai_open_set_vision.tools import GetGrabbingPointTool
3331

3432

3533
class MockGetROS2TopicsNamesAndTypesTool(GetROS2TopicsNamesAndTypesTool):

src/rai_bench/rai_bench/tool_calling_agent_bench/ros2_agent_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from langchain_core.messages import AIMessage
2222
from langchain_core.messages.tool import ToolCall
2323
from langchain_core.tools import BaseTool
24-
from rai.tools.ros.manipulation import MoveToPointToolInput
24+
from rai.tools.ros2 import MoveToPointToolInput
2525

2626
from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
2727
ROS2ToolCallingAgentTask,

src/rai_core/rai/communication/ros2/api/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414

1515
from .action import ROS2ActionAPI
1616
from .base import IROS2Message
17+
from .conversion import (
18+
convert_ros_img_to_base64,
19+
convert_ros_img_to_cv2mat,
20+
convert_ros_img_to_ndarray,
21+
import_message_from_str,
22+
ros2_message_to_dict,
23+
)
1724
from .service import ROS2ServiceAPI
1825
from .topic import ConfigurableROS2TopicAPI, ROS2TopicAPI, TopicConfig
1926

@@ -24,4 +31,9 @@
2431
"ROS2ServiceAPI",
2532
"ROS2TopicAPI",
2633
"TopicConfig",
34+
"convert_ros_img_to_base64",
35+
"convert_ros_img_to_cv2mat",
36+
"convert_ros_img_to_ndarray",
37+
"import_message_from_str",
38+
"ros2_message_to_dict",
2739
]

src/rai_core/rai/communication/ros2/api/action.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
BaseROS2API,
6262
IROS2Message,
6363
)
64-
from rai.tools.ros.utils import import_message_from_str
64+
from rai.communication.ros2.api.conversion import import_message_from_str
6565

6666

6767
class ROS2ActionData(TypedDict):

src/rai_core/rai/communication/ros2/api/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
)
4343
from rclpy.topic_endpoint_info import TopicEndpointInfo
4444

45-
from rai.tools.ros.utils import import_message_from_str
45+
from rai.communication.ros2.api.conversion import import_message_from_str
4646

4747

4848
@runtime_checkable

src/rai_core/rai/tools/ros/utils.py renamed to src/rai_core/rai/communication/ros2/api/conversion.py

Lines changed: 22 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,37 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
1615
import base64
17-
from typing import Optional, Type, Union, cast
16+
from typing import Any, OrderedDict, Type, cast
1817

1918
import cv2
2019
import numpy as np
21-
import rclpy
22-
import rclpy.executors
23-
import rclpy.node
24-
import rclpy.time
20+
import rosidl_runtime_py.convert
21+
import rosidl_runtime_py.set_message
22+
import rosidl_runtime_py.utilities
2523
import sensor_msgs.msg
2624
from cv_bridge import CvBridge
27-
from rclpy.duration import Duration
28-
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
29-
from rclpy.node import Node
30-
from rclpy.qos import QoSProfile
31-
from rclpy.signals import SignalHandlerGuardCondition
32-
from rclpy.utilities import timeout_sec_to_nsec
3325
from rosidl_parser.definition import NamespacedType
3426
from rosidl_runtime_py.import_message import import_message_from_namespaced_type
3527
from rosidl_runtime_py.utilities import get_namespaced_type
36-
from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped
28+
29+
30+
def ros2_message_to_dict(message: Any) -> OrderedDict[str, Any]:
31+
"""Convert any ROS2 message into a dictionary.
32+
33+
Args:
34+
message: A ROS2 message instance
35+
36+
Returns:
37+
A dictionary representation of the message
38+
39+
Raises:
40+
TypeError: If the input is not a valid ROS2 message
41+
"""
42+
msg_dict: OrderedDict[str, Any] = rosidl_runtime_py.convert.message_to_ordereddict(
43+
message
44+
) # type: ignore
45+
return msg_dict
3746

3847

3948
def import_message_from_str(msg_type: str) -> Type[object]:
@@ -101,81 +110,3 @@ def convert_ros_img_to_base64(msg: sensor_msgs.msg.Image) -> str:
101110
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
102111
image_data = cv2.imencode(".png", cv_image)[1].tostring() # type: ignore
103112
return base64.b64encode(image_data).decode("utf-8") # type: ignore
104-
105-
106-
# Copied from https://github.com/ros2/rclpy/blob/jazzy/rclpy/rclpy/wait_for_message.py, to support humble
107-
def wait_for_message(
108-
msg_type: Type[object],
109-
node: "Node",
110-
topic: str,
111-
*,
112-
qos_profile: Union[QoSProfile, int] = 1,
113-
time_to_wait: float = -1,
114-
) -> tuple[bool, Optional[object]]:
115-
"""
116-
Wait for the next incoming message.
117-
118-
:param msg_type: message type
119-
:param node: node to initialize the subscription on
120-
:param topic: topic name to wait for message
121-
:param qos_profile: QoS profile to use for the subscription
122-
:param time_to_wait: seconds to wait before returning
123-
:returns: (True, msg) if a message was successfully received, (False, None) if message
124-
could not be obtained or shutdown was triggered asynchronously on the context.
125-
"""
126-
context = node.context
127-
wait_set = _rclpy.WaitSet(1, 1, 0, 0, 0, 0, context.handle)
128-
wait_set.clear_entities()
129-
130-
sub = node.create_subscription(
131-
msg_type, topic, lambda _: None, qos_profile=qos_profile
132-
)
133-
try:
134-
wait_set.add_subscription(sub.handle)
135-
sigint_gc = SignalHandlerGuardCondition(context=context)
136-
wait_set.add_guard_condition(sigint_gc.handle)
137-
138-
timeout_nsec = timeout_sec_to_nsec(time_to_wait)
139-
wait_set.wait(timeout_nsec)
140-
141-
subs_ready = wait_set.get_ready_entities("subscription")
142-
guards_ready = wait_set.get_ready_entities("guard_condition")
143-
144-
if guards_ready:
145-
if sigint_gc.handle.pointer in guards_ready:
146-
return False, None
147-
148-
if subs_ready:
149-
if sub.handle.pointer in subs_ready:
150-
msg_info = sub.handle.take_message(sub.msg_type, sub.raw)
151-
if msg_info is not None:
152-
return True, msg_info[0]
153-
finally:
154-
# TODO(boczekbartek): uncomment when rclpy resolves: https://github.com/ros2/rclpy/issues/1142
155-
# node.destroy_subscription(sub)
156-
pass
157-
158-
return False, None
159-
160-
161-
def get_transform(
162-
node: rclpy.node.Node,
163-
target_frame: str,
164-
source_frame: str,
165-
timeout_sec: float = 5.0,
166-
) -> TransformStamped:
167-
tf_buffer = Buffer(node=node)
168-
tf_listener = TransformListener(tf_buffer, node)
169-
170-
transform: Optional[TransformStamped] = tf_buffer.lookup_transform(
171-
target_frame, source_frame, rclpy.time.Time(), timeout=Duration(seconds=3)
172-
)
173-
174-
tf_listener.unregister()
175-
176-
if transform is None:
177-
raise LookupException(
178-
f"Could not find transform from {source_frame} to {target_frame} in {timeout_sec} seconds"
179-
)
180-
181-
return transform

src/rai_core/rai/communication/ros2/api/service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from rai.communication.ros2.api.base import (
3535
BaseROS2API,
3636
)
37-
from rai.tools.ros.utils import import_message_from_str
37+
from rai.communication.ros2.api.conversion import import_message_from_str
3838

3939

4040
class ROS2ServiceAPI(BaseROS2API):

0 commit comments

Comments
 (0)