1515import  threading 
1616import  time 
1717import  uuid 
18- from  typing  import  Any , Callable , Dict , List , Literal , Optional , Tuple 
18+ from  collections  import  OrderedDict 
19+ from  typing  import  Any , Callable , Dict , List , Literal , Optional , Tuple , Union , cast 
1920
21+ import  numpy  as  np 
2022import  rclpy 
2123import  rclpy .executors 
2224import  rclpy .node 
2325import  rclpy .time 
26+ import  rosidl_runtime_py .convert 
27+ from  cv_bridge  import  CvBridge 
28+ from  PIL  import  Image 
29+ from  pydub  import  AudioSegment 
2430from  rclpy .duration  import  Duration 
2531from  rclpy .executors  import  MultiThreadedExecutor 
2632from  rclpy .node  import  Node 
2733from  rclpy .qos  import  QoSProfile 
34+ from  sensor_msgs .msg  import  Image  as  ROS2Image 
2835from  tf2_ros  import  Buffer , LookupException , TransformListener , TransformStamped 
2936
37+ import  rai_interfaces .msg 
3038from  rai .communication  import  (
3139    ARIConnector ,
3240    ARIMessage ,
4149    ROS2TopicAPI ,
4250    TopicConfig ,
4351)
52+ from  rai_interfaces .msg  import  HRIMessage  as  ROS2HRIMessage_ 
53+ from  rai_interfaces .msg ._audio_message  import  (
54+     AudioMessage  as  ROS2HRIMessage__Audio ,
55+ )
4456
4557
4658class  ROS2ARIMessage (ARIMessage ):
@@ -200,26 +212,95 @@ class ROS2HRIMessage(HRIMessage):
200212    def  __init__ (self , payload : HRIPayload , message_author : Literal ["ai" , "human" ]):
201213        super ().__init__ (payload , message_author )
202214
215+     @classmethod  
216+     def  from_ros2 (
217+         cls , msg : rai_interfaces .msg .HRIMessage , message_author : Literal ["ai" , "human" ]
218+     ):
219+         cv_bridge  =  CvBridge ()
220+         images  =  [
221+             cv_bridge .imgmsg_to_cv2 (img_msg , "rgb8" )
222+             for  img_msg  in  cast (List [ROS2Image ], msg .images )
223+         ]
224+         pil_images  =  [Image .fromarray (img ) for  img  in  images ]
225+         audio_segments  =  [
226+             AudioSegment (
227+                 data = audio_msg .audio ,
228+                 frame_rate = audio_msg .sample_rate ,
229+                 sample_width = 2 ,  # bytes, int16 
230+                 channels = audio_msg .channels ,
231+             )
232+             for  audio_msg  in  msg .audios 
233+         ]
234+         return  ROS2HRIMessage (
235+             payload = HRIPayload (text = msg .text , images = pil_images , audios = audio_segments ),
236+             message_author = message_author ,
237+         )
238+ 
239+     def  to_ros2_dict (self ) ->  OrderedDict [str , Any ]:
240+         cv_bridge  =  CvBridge ()
241+         assert  isinstance (self .payload , HRIPayload )
242+         img_msgs  =  [
243+             cv_bridge .cv2_to_imgmsg (np .array (img ), "rgb8" )
244+             for  img  in  self .payload .images 
245+         ]
246+         audio_msgs  =  [
247+             ROS2HRIMessage__Audio (
248+                 audio = audio .raw_data ,
249+                 sample_rate = audio .frame_rate ,
250+                 channels = audio .channels ,
251+             )
252+             for  audio  in  self .payload .audios 
253+         ]
254+ 
255+         return  cast (
256+             OrderedDict [str , Any ],
257+             rosidl_runtime_py .convert .message_to_ordereddict (
258+                 ROS2HRIMessage_ (
259+                     text = self .payload .text ,
260+                     images = img_msgs ,
261+                     audios = audio_msgs ,
262+                 )
263+             ),
264+         )
265+ 
203266
204267class  ROS2HRIConnector (HRIConnector [ROS2HRIMessage ]):
205268    def  __init__ (
206269        self ,
207270        node_name : str  =  f"rai_ros2_hri_connector_{ str (uuid .uuid4 ())[- 12 :]}  ,
208-         targets : List [Tuple [str , TopicConfig ]] =  [],
209-         sources : List [Tuple [str , TopicConfig ]] =  [],
271+         targets : List [Union [ str ,  Tuple [str , TopicConfig ] ]] =  [],
272+         sources : List [Union [ str ,  Tuple [str , TopicConfig ] ]] =  [],
210273    ):
211-         configured_targets  =  [target [0 ] for  target  in  targets ]
212-         configured_sources  =  [source [0 ] for  source  in  sources ]
274+         configured_targets  =  [
275+             target [0 ] if  isinstance (target , tuple ) else  target  for  target  in  targets 
276+         ]
277+         configured_sources  =  [
278+             source [0 ] if  isinstance (source , tuple ) else  source  for  source  in  sources 
279+         ]
213280
214-         self ._configure_publishers (targets )
215-         self ._configure_subscribers (sources )
281+         _targets  =  [
282+             target 
283+             if  isinstance (target , tuple )
284+             else  (target , TopicConfig (is_subscriber = False ))
285+             for  target  in  targets 
286+         ]
287+         _sources  =  [
288+             source 
289+             if  isinstance (source , tuple )
290+             else  (source , TopicConfig (is_subscriber = True ))
291+             for  source  in  sources 
292+         ]
216293
217-         super ().__init__ (configured_targets , configured_sources )
218294        self ._node  =  Node (node_name )
219295        self ._topic_api  =  ConfigurableROS2TopicAPI (self ._node )
220296        self ._service_api  =  ROS2ServiceAPI (self ._node )
221297        self ._actions_api  =  ROS2ActionAPI (self ._node )
222298
299+         self ._configure_publishers (_targets )
300+         self ._configure_subscribers (_sources )
301+ 
302+         super ().__init__ (configured_targets , configured_sources )
303+ 
223304        self ._executor  =  MultiThreadedExecutor ()
224305        self ._executor .add_node (self ._node )
225306        self ._thread  =  threading .Thread (target = self ._executor .spin )
@@ -236,7 +317,7 @@ def _configure_subscribers(self, sources: List[Tuple[str, TopicConfig]]):
236317    def  send_message (self , message : ROS2HRIMessage , target : str , ** kwargs ):
237318        self ._topic_api .publish_configured (
238319            topic = target ,
239-             msg_content = message .payload ,
320+             msg_content = message .to_ros2_dict () ,
240321        )
241322
242323    def  receive_message (
@@ -249,16 +330,12 @@ def receive_message(
249330        auto_topic_type : bool  =  True ,
250331        ** kwargs : Any ,
251332    ) ->  ROS2HRIMessage :
252-         if  msg_type  !=  "std_msgs/msg/String" :
253-             raise  ValueError ("ROS2HRIConnector only supports receiving sting messages" )
254333        msg  =  self ._topic_api .receive (
255334            topic = source ,
256335            timeout_sec = timeout_sec ,
257-             msg_type = msg_type ,
258336            auto_topic_type = auto_topic_type ,
259337        )
260-         payload  =  HRIPayload (msg .data )
261-         return  ROS2HRIMessage (payload = payload , message_author = message_author )
338+         return  ROS2HRIMessage .from_ros2 (msg , message_author )
262339
263340    def  service_call (
264341        self , message : ROS2HRIMessage , target : str , timeout_sec : float , ** kwargs : Any 
@@ -284,3 +361,10 @@ def terminate_action(self, action_handle: str, **kwargs: Any):
284361        raise  NotImplementedError (
285362            f"{ self .__class__ .__name__ }  
286363        )
364+ 
365+     def  shutdown (self ):
366+         self ._executor .shutdown ()
367+         self ._thread .join ()
368+         self ._actions_api .shutdown ()
369+         self ._topic_api .shutdown ()
370+         self ._node .destroy_node ()
0 commit comments