diff --git a/example_apps/call_method_tcp.py b/example_apps/call_method_tcp.py index 804bcdb..7fea0ea 100644 --- a/example_apps/call_method_tcp.py +++ b/example_apps/call_method_tcp.py @@ -3,7 +3,10 @@ import logging from someipy import TransportLayerProtocol -from someipy.client_service_instance import MethodResult, construct_client_service_instance +from someipy.client_service_instance import ( + MethodResult, + construct_client_service_instance, +) from someipy.service import ServiceBuilder from someipy.service_discovery import construct_service_discovery from someipy.logging import set_someipy_log_level @@ -17,6 +20,7 @@ SAMPLE_INSTANCE_ID = 0x5678 SAMPLE_METHOD_ID = 0x0123 + async def main(): # It's possible to configure the logging level of the someipy library, e.g. logging.INFO, logging.DEBUG, logging.WARN, .. @@ -25,8 +29,10 @@ async def main(): # Since the construction of the class ServiceDiscoveryProtocol is not trivial and would require an async __init__ function # use the construct_service_discovery function # The local interface IP address needs to be passed so that the src-address of all SD UDP packets is correctly set - service_discovery = await construct_service_discovery(SD_MULTICAST_GROUP, SD_PORT, INTERFACE_IP) - + service_discovery = await construct_service_discovery( + SD_MULTICAST_GROUP, SD_PORT, INTERFACE_IP + ) + addition_service = ( ServiceBuilder() .with_service_id(SAMPLE_SERVICE_ID) @@ -41,7 +47,7 @@ async def main(): endpoint=(ipaddress.IPv4Address(INTERFACE_IP), 3002), ttl=5, sd_sender=service_discovery, - protocol=TransportLayerProtocol.TCP + protocol=TransportLayerProtocol.TCP, ) # The service instance has to be attached always to the ServiceDiscoveryProtocol object, so that the service instance @@ -50,11 +56,18 @@ async def main(): try: while True: - + method_parameter = Addends(addend1=1, addend2=2) - method_success, method_result = await client_instance_addition.call_method(SAMPLE_METHOD_ID, method_parameter.serialize()) + + # The call method function returns a tuple with the first element being a MethodResult enum + method_success, method_result = await client_instance_addition.call_method( + SAMPLE_METHOD_ID, method_parameter.serialize() + ) + # Check the result of the method call and handle it accordingly if method_success == MethodResult.SUCCESS: - print(f"Received result for method: {' '.join(f'0x{b:02x}' for b in method_result)}") + print( + f"Received result for method: {' '.join(f'0x{b:02x}' for b in method_result)}" + ) try: sum = Sum().deserialize(method_result) print(f"Sum: {sum.value.value}") @@ -68,6 +81,8 @@ async def main(): print("Service not yet available..") await asyncio.sleep(2) + + # When the application is canceled by the user, the asyncio.CancelledError is raised except asyncio.CancelledError: print("Shutdown..") finally: @@ -81,7 +96,4 @@ async def main(): if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - pass + asyncio.run(main()) diff --git a/src/someipy/_internal/someip_data_processor.py b/src/someipy/_internal/someip_data_processor.py index b5fdaf8..9c17676 100644 --- a/src/someipy/_internal/someip_data_processor.py +++ b/src/someipy/_internal/someip_data_processor.py @@ -1,3 +1,4 @@ +from enum import Enum import struct from someipy._internal.someip_header import SomeIpHeader from someipy._internal.someip_message import SomeIpMessage @@ -5,60 +6,62 @@ class SomeipDataProcessor: + class State(Enum): + HEADER = 1 + PAYLOAD = 2 + PENDING = 3 + def __init__(self, datagram_mode=True): - self._buffer = bytes() - self._expected_bytes = 0 + self._reset() self._datagram_mode = datagram_mode - self.someip_message = None + self._someip_message = None def _reset(self): + self._state = SomeipDataProcessor.State.HEADER self._buffer = bytes() - self._expected_bytes = 0 + self._expected_bytes = 8 # 2x 32-bit for header + self._total_length = 0 def process_data(self, new_data: bytes) -> bool: - received_length = len(new_data) - - # UDP case - if self._datagram_mode: - header = SomeIpHeader.from_buffer(new_data) - expected_total_length = 8 + header.length - payload_length = expected_total_length - 16 - if received_length == expected_total_length: - self.someip_message = SomeIpMessage(header=header, payload=new_data[16:]) - return True + self._buffer += new_data + current_length = len(self._buffer) + + if self._state == SomeipDataProcessor.State.HEADER: + if current_length < self._expected_bytes: + # The header was not fully received yet + return False else: - # Malformed package -> return False + _, _, length = struct.unpack(">HHI", self._buffer[0:8]) + self._total_length = length + 8 + self._expected_bytes = self._total_length - current_length + self._state = SomeipDataProcessor.State.PAYLOAD + + elif self._state == SomeipDataProcessor.State.PAYLOAD: + if current_length < self._total_length: + # The payload was not fully received yet + self._expected_bytes = self._total_length - current_length return False - - # From here on: TCP case - if self._expected_bytes == 0 and len(self._buffer) == 0: + else: + payload_length = self._total_length - 16 + header = SomeIpHeader.from_buffer(self._buffer) + self._someip_message = SomeIpMessage( + header=header, payload=self._buffer[16 : (16 + payload_length)] + ) - if received_length >= 8: - service_id, method_id, length = struct.unpack(">HHI", new_data[0:8]) - expected_total_length = 8 + length - payload_length = expected_total_length - 16 + self._state = SomeipDataProcessor.State.HEADER + # If more data was received over the current message boundary, keep the data + self._buffer = self._buffer[self._total_length :] + self._expected_bytes = 8 + self._total_length = 0 - # Case 1: Received exactly one SOME/IP message - if received_length == expected_total_length: - header = SomeIpHeader.from_buffer(new_data) - self.someip_message = SomeIpMessage(header=header, payload=new_data[16:(16+payload_length)]) - self._reset() - return True - # Case 2: Received less bytes than expected - elif received_length < expected_total_length: - self._expected_bytes = (expected_total_length - received_length) - self._buffer = new_data - return False - # Case 3: Received more bytes than expected - elif received_length > expected_total_length: - # Assume it is the beginning of a new SOME/IP message, store remaining bytes in buffer - end_payload = 16 + payload_length - header = SomeIpHeader.from_buffer(new_data) - self.someip_message = SomeIpMessage(header=header, payload=new_data[16:end_payload]) - self._buffer = new_data[end_payload:] - self._expected_bytes = 0 + return True - return True - - else: - pass # store in buffer \ No newline at end of file + @property + def someip_message(self): + """Returns the SomeIpMessage that was received and interpreted""" + return self._someip_message + + @property + def expected_bytes(self): + """Returns the number of bytes that are expected to complete the current message""" + return self._expected_bytes diff --git a/src/someipy/client_service_instance.py b/src/someipy/client_service_instance.py index c467142..511758b 100644 --- a/src/someipy/client_service_instance.py +++ b/src/someipy/client_service_instance.py @@ -4,6 +4,7 @@ from typing import Iterable, Tuple, Callable, Set, List from someipy import Service +from someipy._internal.someip_data_processor import SomeipDataProcessor from someipy._internal.someip_sd_header import ( SdService, TransportLayerProtocol, @@ -192,7 +193,7 @@ async def call_method( endpoint_to_str_int_tuple(self._found_services[0].service.endpoint), ) - # After sending the method call wait for one second + # After sending the method call wait for maximum one second try: await asyncio.wait_for(self._method_call_future, 1.0) except asyncio.TimeoutError: @@ -206,7 +207,6 @@ async def call_method( def someip_message_received( self, someip_message: SomeIpMessage, addr: Tuple[str, int] ) -> None: - print("Some ip message received") if ( someip_message.header.client_id == 0x00 and someip_message.header.message_type == MessageType.NOTIFICATION.value @@ -274,9 +274,6 @@ def find_service_update(self): pass def offer_service_update(self, offered_service: SdService): - # if len(self._eventgroups_to_subscribe) == 0: - # return - if self._service.id != offered_service.service_id: return if self._instance_id != offered_service.instance_id: @@ -289,6 +286,9 @@ def offer_service_update(self, offered_service: SdService): if FoundService(offered_service) not in self._found_services: self._found_services.append(FoundService(offered_service)) + if len(self._eventgroups_to_subscribe) == 0: + return + # Try to subscribe to requested event groups for eventgroup_to_subscribe in self._eventgroups_to_subscribe: ( @@ -362,66 +362,24 @@ async def setup_tcp_connection( if self._tcp_connection.is_open(): self._tcp_connection_established_event.set() - class State(Enum): - HEADER = 1 - PAYLOAD = 2 - PENDING = 3 + get_logger(_logger_name).debug(f"Start reading on port {src_port}") - state = State.HEADER - - expected_bytes = 8 # 2x 32-bit for header - header_data = bytes() - data = bytes() - get_logger(_logger_name).debug(f"Start TCP read on port {src_port}") + someip_processor = SomeipDataProcessor() while self._tcp_connection.is_open(): try: - if state == State.HEADER: - while len(data) < expected_bytes: - new_data = await asyncio.wait_for( - self._tcp_connection.reader.read(8), 3.0 - ) - print("Received data") - data += new_data - service_id, method_id, length = struct.unpack( - ">HHI", data[0:8] - ) - header_data = data[0:8] - - # The length bytes also covers 8 bytes header data without payload - expected_bytes = length - state = State.PAYLOAD - - elif state == State.PAYLOAD: - data = bytes() - while len(data) < expected_bytes: - new_data = await asyncio.wait_for( - self._tcp_connection.reader.read(expected_bytes), - 3.0, - ) - data += new_data - - # Request ID to return code is also covered in the payload state, but needed for the SOME/IP header - header_data = header_data + data[0:8] - payload_data = data[8 : (8 + expected_bytes)] - - message_data = header_data + payload_data - someip_header = SomeIpHeader.from_buffer(buf=message_data) - someip_message = SomeIpMessage(someip_header, payload_data) + new_data = await asyncio.wait_for( + self._tcp_connection.reader.read( + someip_processor.expected_bytes + ), + 3.0, + ) + if someip_processor.process_data(new_data): self.someip_message_received( - someip_message, (dst_ip, dst_port) + someip_processor.someip_message, (dst_ip, dst_port) ) - if len(data) == expected_bytes: - # If the exact amount of needed bytes were received reset the buffer - data = bytes() - # If more data was received, keep the remaining part for the next reception - # TODO: this needs more logic - data = data[expected_bytes:] - state = State.HEADER - expected_bytes = 8 - except asyncio.TimeoutError: get_logger(_logger_name).debug( f"Timeout reading from TCP connection ({src_ip}, {src_port})" diff --git a/src/someipy/server_service_instance.py b/src/someipy/server_service_instance.py index e984778..310192a 100644 --- a/src/someipy/server_service_instance.py +++ b/src/someipy/server_service_instance.py @@ -11,9 +11,7 @@ build_offer_service_sd_header, build_subscribe_eventgroup_ack_sd_header, ) -from someipy._internal.someip_header import ( - SomeIpHeader -) +from someipy._internal.someip_header import SomeIpHeader from someipy._internal.someip_sd_header import ( SdService, TransportLayerProtocol, @@ -40,6 +38,7 @@ _logger_name = "server_service_instance" + class ServerServiceInstance(ServiceDiscoveryObserver): _service: Service _instance_id: int @@ -119,7 +118,9 @@ def send_event(self, event_group_id: int, event_id: int, payload: bytes) -> None endpoint_to_str_int_tuple(sub.endpoint), ) - def someip_message_received(self, message: SomeIpMessage, addr: Tuple[str, int]) -> None: + def someip_message_received( + self, message: SomeIpMessage, addr: Tuple[str, int] + ) -> None: """ Handle a received Some/IP message, typically when a client uses an offered service. @@ -137,7 +138,6 @@ def someip_message_received(self, message: SomeIpMessage, addr: Tuple[str, int]) - The protocol and interface version are not checked yet. - If the message type in the received header is not a request, a warning is logged. """ - print("Received message") header = message.header payload_to_return = bytes() header_to_return = header @@ -148,8 +148,8 @@ def send_response(): # Update length in header to the correct length header_to_return.length = 8 + len(payload_to_return) self._someip_endpoint.sendto( - header_to_return.to_buffer() + payload_to_return, addr - ) + header_to_return.to_buffer() + payload_to_return, addr + ) if header.service_id != self._service.id: get_logger(_logger_name).warn( @@ -160,7 +160,6 @@ def send_response(): send_response() return - if header.method_id not in self._service.methods.keys(): get_logger(_logger_name).warn( f"Unknown method ID received from {addr}: ID 0x{header.method_id:04X}" @@ -185,7 +184,7 @@ def send_response(): payload_to_return = payload_result send_response() - + else: get_logger(_logger_name).warn( f"Unknown message type received from {addr}: Type 0x{header.message_type:04X}" @@ -232,7 +231,7 @@ def subscribe_eventgroup_update( None """ - + # From SD specification: # [PRS_SOMEIPSD_00829] When receiving a SubscribeEventgroupAck or Sub- # scribeEventgroupNack the Service ID, Instance ID, Eventgroup ID, and Major Ver- diff --git a/tests/test_someip_data_processor.py b/tests/test_someip_data_processor.py index 1c37b87..066f2ac 100644 --- a/tests/test_someip_data_processor.py +++ b/tests/test_someip_data_processor.py @@ -49,24 +49,24 @@ def corrupt_someip_message() -> SomeIpMessage: def test_process_with_datagrams(valid_someip_message): data = valid_someip_message.header.to_buffer() + valid_someip_message.payload - processor = SomeipDataProcessor(datagram_mode=True) + processor = SomeipDataProcessor() result = processor.process_data(data) assert result is True assert processor.someip_message.header == valid_someip_message.header - assert processor._expected_bytes == 0 + assert processor.expected_bytes == 8 assert len(processor._buffer) == 0 result = processor.process_data(data) assert result is True assert processor.someip_message.header == valid_someip_message.header - assert processor._expected_bytes == 0 + assert processor._expected_bytes == 8 assert len(processor._buffer) == 0 def test_process_with_malformed_datagrams(corrupt_someip_message): data = corrupt_someip_message.header.to_buffer() + corrupt_someip_message.payload - processor = SomeipDataProcessor(datagram_mode=True) + processor = SomeipDataProcessor() result = processor.process_data(data)