Skip to content

Commit

Permalink
(fix) Updates AllowedTopics config and its validator
Browse files Browse the repository at this point in the history
  • Loading branch information
mkabtoul committed Dec 5, 2024
1 parent 48ee543 commit 0b367f1
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 85 deletions.
12 changes: 8 additions & 4 deletions ros_sugar/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,16 @@ def _wrapper(*, output, **_):
_wrapper.__name__ = func.__name__
return _wrapper

def attach_custom_callback(self, input_topic: Topic, callable: Callable) -> None:
def attach_custom_callback(self, input_topic: Topic, func: Callable) -> None:
"""
Method to attach custom method to subscriber callbacks
"""
if not callable(callable):
raise TypeError("A custom callback must be a Callable")
if not callable(func):
raise TypeError(f"A custom callback must be a Callable, got {type(func)}")
if callback := self.callbacks.get(input_topic.name):
if not callback:
raise TypeError("Specified input topic does not exist")
callback.on_callback_execute(callable)
callback.on_callback_execute(func)

def add_callback_postprocessor(self, input_topic: Topic, func: Callable) -> None:
"""Adds a callable as a post processor for topic callback.
Expand Down Expand Up @@ -1455,6 +1455,8 @@ def _replace_input_topic(
if idx:
self.in_topics.pop(idx)
self.in_topics.insert(idx, new_topic)
self.callbacks.pop(normalized_topic_name)
self.callbacks[new_name] = callback
return None

def _replace_output_topic(
Expand Down Expand Up @@ -1497,6 +1499,8 @@ def _replace_output_topic(
if idx:
self.out_topics.pop(idx)
self.out_topics.insert(idx, new_topic)
self.publishers_dict.pop(normalized_topic_name)
self.publishers_dict[new_name] = publisher
return None

@log_srv
Expand Down
16 changes: 8 additions & 8 deletions ros_sugar/core/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def _get_attribute_type(obj: Any, attrs: tuple):
result = getattr(result, attr)
return type(result)
except AttributeError as e:
raise AttributeError(f"Given attribute is not part of class {type(obj)}") from e
raise AttributeError(
f"Given nested attributes '{attrs}' are not part of class {type(obj)}"
) from e


def _check_attribute(cls, expected_type, attrs: tuple):
Expand Down Expand Up @@ -328,7 +330,6 @@ def __init__(
nested_attributes: Union[str, List[str]],
handle_once: bool = False,
keep_event_delay: float = 0.0,
topic_template: Optional[Topic] = None,
) -> None:
"""Creates an event
Expand Down Expand Up @@ -358,10 +359,7 @@ def __init__(

# Init from dictionary values
elif isinstance(event_source, Dict):
if topic_template:
self.set_dictionary(event_source, topic_template)
else:
self.dictionary = event_source
self.dictionary = event_source

elif isinstance(event_source, Topic):
self.event_topic = event_source
Expand All @@ -382,10 +380,12 @@ def __init__(

# Check if given trigger is of valid type
if trigger_value and not _check_attribute(
self.event_topic.ros_msg_type, type(self.trigger_ref_value), self._attrs
self.event_topic.msg_type._ros_type,
type(self.trigger_ref_value),
self._attrs,
):
raise TypeError(
f"Cannot initiate with trigger of type {type(trigger_value)} for a data of type {_get_attribute_type(self.event_topic.ros_msg_type, self._attrs)}"
f"Cannot initiate with trigger of type {type(trigger_value)} for a data of type {_get_attribute_type(self.event_topic.msg_type._ros_type, self._attrs)}"
)

# Init trigger as False
Expand Down
2 changes: 0 additions & 2 deletions ros_sugar/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

def json_to_events_list(
json_obj: Union[str, bytes, bytearray],
topic_template: Optional[Topic] = None,
) -> List:
"""
Loads a list of events from a JSON object
Expand Down Expand Up @@ -48,7 +47,6 @@ def json_to_events_list(
event_as_dict,
event_as_dict["trigger_ref_value"],
nested_attributes=[],
topic_template=topic_template,
)
# Add to events dictionary
events_list.append(deepcopy(new_event))
Expand Down
6 changes: 2 additions & 4 deletions ros_sugar/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from .publisher import Publisher
from .topic import (
Topic,
AllowedTopic,
RestrictedTopicsConfig,
AllowedTopics,
get_all_msg_types,
get_msg_type,
)
Expand All @@ -14,8 +13,7 @@
__all__ = [
"Publisher",
"Topic",
"AllowedTopic",
"RestrictedTopicsConfig",
"AllowedTopics",
"get_all_msg_types",
"get_msg_type",
]
6 changes: 5 additions & 1 deletion ros_sugar/io/supported_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def add_additional_datatypes(types: List[type]) -> None:
# Update the existing class with non-None attributes from the new class
existing_class = type_dict[new_type.__name__]

if existing_class == SupportedType:
# Skip parent
continue

if not existing_class.callback:
existing_class.callback = new_type.callback

Expand Down Expand Up @@ -303,7 +307,7 @@ def convert(

# Set MetaData
msg.info = ROSMapMetaData()
msg.info.map_load_time = msg_header.stamp
msg.info.map_load_time = msg.header.stamp
msg.info.width = output.shape[0]
msg.info.height = output.shape[1]
msg.info.resolution = resolution
Expand Down
74 changes: 11 additions & 63 deletions ros_sugar/io/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,12 @@ def _msg_type_validator(self, _, val):


@define(kw_only=True)
class AllowedTopic(BaseAttrs):
"""Configure a key name and allowed types to restrict a component Topic"""
class AllowedTopics(BaseAttrs):
"""Configure allowed types to restrict a component Topic"""

types: List[Union[type[supported_types.SupportedType], str]] = field(
converter=_get_msg_types,
validator=base_validators.list_contained_in(get_all_msg_types()),
converter=_get_msg_types
)
key: str = field(default="")
number_required: int = field(
default=1, validator=base_validators.in_range(min_value=0, max_value=100)
)
Expand All @@ -174,67 +172,17 @@ class AllowedTopic(BaseAttrs):
default=0, validator=base_validators.in_range(min_value=-1, max_value=100)
)

@types.validator
def _types_validator(self, _, vals):
msg_types = get_all_msg_types()
if any(v not in msg_types for v in vals):
raise ValueError(
f"Got value of 'msg_type': {vals}, which is not in available datatypes. Topics can only be created with one of the following types: { {msg_t.__name__: msg_t for msg_t in msg_types} }"
)

def __attrs_post_init__(self):
"""__attrs_post_init__."""
if self.number_required == 0 and self.number_optional == 0:
raise ValueError(
"Logical error - Cannot define an AllowedTopic with zero optional and required streams"
)


class RestrictedTopicsConfig:
"""
Class used to define a restriction on component inputs/outputs topics
"""

@classmethod
def keys(cls) -> List[str]:
"""keys.
:rtype: List[str]
"""
return [
member.key
for _, member in inspect.getmembers(
cls, lambda a: isinstance(a, AllowedTopic)
)
]

@classmethod
def types(cls, key: str) -> List[Union[supported_types.SupportedType, str]]:
"""types.
:param key:
:type key: str
:rtype: List[Union[supported_types.SupportedType, str]]
"""
for _, member in inspect.getmembers(cls, lambda a: isinstance(a, AllowedTopic)):
if member.key == key:
return member.types
raise KeyError(f"Unknown Topic key '{key}'")

@classmethod
def required_number(cls, key: str) -> int:
"""required_number.
:param key:
:type key: str
:rtype: int
"""
for _, member in inspect.getmembers(cls, lambda a: isinstance(a, AllowedTopic)):
if member.key == key:
return member.number_required
raise KeyError(f"Unknown Topic key '{key}'")

@classmethod
def optional_number(cls, key: str) -> int:
"""optional_number.
:param key:
:type key: str
:rtype: int
"""
for _, member in inspect.getmembers(cls, lambda a: isinstance(a, AllowedTopic)):
if member.key == key:
return member.number_optional
raise KeyError(f"Unknown Topic key '{key}'")
6 changes: 3 additions & 3 deletions ros_sugar/launch/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def add_pkg(
self._pkg_log_level[component.node_name] = ros_log_level
if self._config_file:
component._config_file = self._config_file
component.configure(self._config_file)
component.config_from_yaml(self._config_file)

def _setup_component_events_handlers(self, comp: BaseComponent):
"""Parse a component events/actions from the overall components actions
Expand Down Expand Up @@ -705,12 +705,12 @@ def configure(
if component_name:
for component in self.components:
if component.node_name == component_name:
component.configure(config_file)
component.config_from_yaml(config_file)
return

# If no component is specified -> configure all components
for component in self.components:
component.configure(config_file)
component.config_from_yaml(config_file)

def add_py_executable(self, path_to_executable: str, name: str = "python3"):
"""
Expand Down

0 comments on commit 0b367f1

Please sign in to comment.