Skip to content

Commit

Permalink
Add count services, clients & test (#1024)
Browse files Browse the repository at this point in the history
* Add count services, clients & test

Signed-off-by: leeminju531 <dlalswn531@naver.com>
  • Loading branch information
leeminju531 authored Oct 9, 2023
1 parent bf6b22e commit e3d37c5
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 0 deletions.
36 changes: 36 additions & 0 deletions rclpy/rclpy/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,42 @@ def count_subscribers(self, topic_name: str) -> int:
return self._count_publishers_or_subscribers(
topic_name, self.handle.get_count_subscribers)

def _count_clients_or_servers(self, service_name, func):
fq_service_name = expand_topic_name(service_name, self.get_name(), self.get_namespace())
validate_full_topic_name(fq_service_name, is_service=True)
with self.handle:
return func(fq_service_name)

def count_clients(self, service_name: str) -> int:
"""
Return the number of clients on a given service.
`service_name` may be a relative, private, or fully qualified service name.
A relative or private service is expanded using this node's namespace and name.
The queried service name is not remapped.
:param service_name: the service_name on which to count the number of clients.
:return: the number of clients on the service.
"""
with self.handle:
return self._count_clients_or_servers(
service_name, self.handle.get_count_clients)

def count_services(self, service_name: str) -> int:
"""
Return the number of servers on a given service.
`service_name` may be a relative, private, or fully qualified service name.
A relative or private service is expanded using this node's namespace and name.
The queried service name is not remapped.
:param service_name: the service_name on which to count the number of clients.
:return: the number of servers on the service.
"""
with self.handle:
return self._count_clients_or_servers(
service_name, self.handle.get_count_services)

def _get_info_by_topic(
self,
topic_name: str,
Expand Down
30 changes: 30 additions & 0 deletions rclpy/src/rclpy/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,30 @@ Node::get_count_subscribers(const char * topic_name)
return count;
}

size_t
Node::get_count_clients(const char * service_name)
{
size_t count = 0;
rcl_ret_t ret = rcl_count_clients(rcl_node_.get(), service_name, &count);
if (RCL_RET_OK != ret) {
throw RCLError("Error in rcl_count_clients");
}

return count;
}

size_t
Node::get_count_services(const char * service_name)
{
size_t count = 0;
rcl_ret_t ret = rcl_count_services(rcl_node_.get(), service_name, &count);
if (RCL_RET_OK != ret) {
throw RCLError("Error in rcl_count_services");
}

return count;
}

py::list
Node::get_names_impl(bool get_enclaves)
{
Expand Down Expand Up @@ -581,6 +605,12 @@ define_node(py::object module)
.def(
"get_count_subscribers", &Node::get_count_subscribers,
"Returns the count of all the subscribers known for that topic in the entire ROS graph.")
.def(
"get_count_clients", &Node::get_count_clients,
"Returns the count of all the clients known for that service in the entire ROS graph.")
.def(
"get_count_services", &Node::get_count_services,
"Returns the count of all the servers known for that service in the entire ROS graph.")
.def(
"get_node_names_and_namespaces", &Node::get_node_names_and_namespaces,
"Get the list of nodes discovered by the provided node")
Expand Down
20 changes: 20 additions & 0 deletions rclpy/src/rclpy/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,26 @@ class Node : public Destroyable, public std::enable_shared_from_this<Node>
size_t
get_count_subscribers(const char * topic_name);

/// Returns the count of all the clients known for that service in the entire ROS graph
/**
* Raises RCLError if an error occurs in rcl
*
* \param[in] service_name Name of the service to count the number of clients
* \return the count of all the clients known for that service in the entire ROS graph
*/
size_t
get_count_clients(const char * service_name);

/// Returns the count of all the servers known for that service in the entire ROS graph
/**
* Raises RCLError if an error occurs in rcl
*
* \param[in] service_name Name of the service to count the number of servers
* \return the count of all the servers known for that service in the entire ROS graph
*/
size_t
get_count_services(const char * service_name);

/// Get the list of nodes discovered by the provided node
/**
* Raises RCLError if the names are unavailable.
Expand Down
41 changes: 41 additions & 0 deletions rclpy/test/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,47 @@ def test_count_publishers_subscribers(self):
with self.assertRaisesRegex(ValueError, 'is invalid'):
self.node.count_publishers('42')

def test_count_clients_services(self):
short_service_name = 'add_two_ints'
fq_service_name = '%s/%s' % (TEST_NAMESPACE, short_service_name)

self.assertEqual(0, self.node.count_clients(fq_service_name))
self.assertEqual(0, self.node.count_services(fq_service_name))

self.node.create_client(GetParameters, short_service_name)
self.assertEqual(1, self.node.count_clients(short_service_name))
self.assertEqual(1, self.node.count_clients(fq_service_name))
self.assertEqual(0, self.node.count_services(short_service_name))
self.assertEqual(0, self.node.count_services(fq_service_name))

self.node.create_service(GetParameters, short_service_name, lambda req: None)
self.assertEqual(1, self.node.count_clients(short_service_name))
self.assertEqual(1, self.node.count_clients(fq_service_name))
self.assertEqual(1, self.node.count_services(short_service_name))
self.assertEqual(1, self.node.count_services(fq_service_name))

self.node.create_client(GetParameters, short_service_name)
self.assertEqual(2, self.node.count_clients(short_service_name))
self.assertEqual(2, self.node.count_clients(fq_service_name))
self.assertEqual(1, self.node.count_services(short_service_name))
self.assertEqual(1, self.node.count_services(fq_service_name))

self.node.create_service(GetParameters, short_service_name, lambda req: None)
self.assertEqual(2, self.node.count_clients(short_service_name))
self.assertEqual(2, self.node.count_clients(fq_service_name))
self.assertEqual(2, self.node.count_services(short_service_name))
self.assertEqual(2, self.node.count_services(fq_service_name))

# error cases
with self.assertRaises(TypeError):
self.node.count_clients(1)
with self.assertRaises(TypeError):
self.node.count_services(1)
with self.assertRaisesRegex(ValueError, 'is invalid'):
self.node.count_clients('42')
with self.assertRaisesRegex(ValueError, 'is invalid'):
self.node.count_services('42')

def test_node_logger(self):
node_logger = self.node.get_logger()
expected_name = '%s.%s' % (TEST_NAMESPACE.replace('/', '.')[1:], TEST_NODE)
Expand Down

0 comments on commit e3d37c5

Please sign in to comment.