diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 47b7c3b8..93b0a00b 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -80,16 +80,16 @@ class MarshallerStreamEndError(Exception): # # For any align value, the correct padding formula is: # -# (align - (offset % align)) % align +# (align - (pos % align)) % align # # However, if align is a power of 2 (always the case here), the slow MOD # operator can be replaced by a bitwise AND: # -# (align - (offset & (align - 1))) & (align - 1) +# (align - (pos & (align - 1))) & (align - 1) # # Which can be simplified to: # -# (-offset) & (align - 1) +# (-pos) & (align - 1) # # class Unmarshaller: @@ -104,7 +104,7 @@ class Unmarshaller: "unix_fds", "buf", "view", - "offset", + "pos", "stream", "sock", "message", @@ -115,13 +115,14 @@ class Unmarshaller: "message_type", "flag", "msg_len", + "_uint32_unpack", ) def __init__(self, stream: io.BufferedRWPair, sock=None): self.unix_fds: List[int] = [] self.buf = bytearray() # Actual buffer self.view = None # Memory view of the buffer - self.offset = 0 + self.pos = 0 self.stream = stream self.sock = sock self.message = None @@ -132,6 +133,8 @@ def __init__(self, stream: io.BufferedRWPair, sock=None): self.message_type: MessageType | None = None self.flag: MessageFlag | None = None self.msg_len = 0 + # Only set if we cannot cast + self._uint32_unpack: Callable | None = None def read_sock(self, length: int) -> bytes: """reads from the socket, storing any fds sent and handling errors @@ -155,21 +158,21 @@ def read_sock(self, length: int) -> bytes: return msg - def read_to_offset(self, offset: int) -> None: + def read_to_pos(self, pos: int) -> None: """ Read from underlying socket into buffer. Raises MarshallerStreamEndError if there is not enough data to be read. - :arg offset: - The offset to read to. If not enough bytes are available in the + :arg pos: + The pos to read to. If not enough bytes are available in the buffer, read more from it. :returns: None """ start_len = len(self.buf) - missing_bytes = offset - (start_len - self.offset) + missing_bytes = pos - (start_len - self.pos) if self.sock is None: data = self.stream.read(missing_bytes) else: @@ -179,7 +182,7 @@ def read_to_offset(self, offset: int) -> None: if data is None: raise MarshallerStreamEndError() self.buf.extend(data) - if len(data) + start_len != offset: + if len(data) + start_len != pos: raise MarshallerStreamEndError() def read_boolean(self, _=None): @@ -187,32 +190,27 @@ def read_boolean(self, _=None): def read_string_cast(self, _=None): """Read a string using cast.""" - self.offset += UINT32_SIZE + (-self.offset & (UINT32_SIZE - 1)) # align - str_start = self.offset + self.pos += UINT32_SIZE + (-self.pos & (UINT32_SIZE - 1)) # align + str_start = self.pos # read terminating '\0' byte as well (str_length + 1) - self.offset += ( - self.view[self.offset - UINT32_SIZE : self.offset].cast(UINT32_CAST)[0] + 1 + self.pos += ( + self.view[self.pos - UINT32_SIZE : self.pos].cast(UINT32_CAST)[0] + 1 ) - return self.buf[str_start : self.offset - 1].decode() + return self.buf[str_start : self.pos - 1].decode() def read_string_unpack(self, _=None): """Read a string using unpack.""" - self.offset += UINT32_SIZE + (-self.offset & (UINT32_SIZE - 1)) # align - str_start = self.offset + self.pos += UINT32_SIZE + (-self.pos & (UINT32_SIZE - 1)) # align + str_start = self.pos # read terminating '\0' byte as well (str_length + 1) - self.offset += ( - self.readers[UINT32_DBUS_TYPE][3].unpack_from( - self.view, str_start - UINT32_SIZE - )[0] - + 1 - ) - return self.buf[str_start : self.offset - 1].decode() + self.pos += self._uint32_unpack(self.view, str_start - UINT32_SIZE)[0] + 1 + return self.buf[str_start : self.pos - 1].decode() def read_signature(self, _=None): - signature_len = self.view[self.offset] # byte - o = self.offset + 1 + signature_len = self.view[self.pos] # byte + o = self.pos + 1 # read terminating '\0' byte as well (str_length + 1) - self.offset = o + signature_len + 1 + self.pos = o + signature_len + 1 return self.buf[o : o + signature_len].decode() def read_variant(self, _=None): @@ -221,39 +219,48 @@ def read_variant(self, _=None): return Variant(tree, self.read_argument(tree.types[0]), verify=False) def read_struct(self, type_: SignatureType): - self.offset += -self.offset & 7 # align 8 + self.pos += -self.pos & 7 # align 8 return [self.read_argument(child_type) for child_type in type_.children] def read_dict_entry(self, type_: SignatureType): - self.offset += -self.offset & 7 # align 8 + self.pos += -self.pos & 7 # align 8 return self.read_argument(type_.children[0]), self.read_argument( type_.children[1] ) def read_array(self, type_: SignatureType): - self.offset += -self.offset & 3 # align 4 for the array - array_length = self.read_argument(UINT32_SIGNATURE) + self.pos += -self.pos & 3 # align 4 for the array + self.pos += ( + -self.pos & (UINT32_SIZE - 1) + ) + UINT32_SIZE # align for the uint32 + if self._uint32_unpack: + array_length = self._uint32_unpack(self.view, self.pos - UINT32_SIZE)[0] + else: + array_length = self.view[self.pos - UINT32_SIZE : self.pos].cast( + UINT32_CAST + )[0] child_type = type_.children[0] if child_type.token in "xtd{(": # the first alignment is not included in the array size - self.offset += -self.offset & 7 # align 8 + self.pos += -self.pos & 7 # align 8 if child_type.token == "y": - self.offset += array_length - return self.buf[self.offset - array_length : self.offset] + self.pos += array_length + return self.buf[self.pos - array_length : self.pos] - beginning_offset = self.offset + beginning_pos = self.pos if child_type.token == "{": result_dict = {} - while self.offset - beginning_offset < array_length: - key, value = self.read_dict_entry(child_type) - result_dict[key] = value + while self.pos - beginning_pos < array_length: + self.pos += -self.pos & 7 # align 8 + key = self.read_argument(child_type.children[0]) + result_dict[key] = self.read_argument(child_type.children[1]) return result_dict result_list = [] - while self.offset - beginning_offset < array_length: + while self.pos - beginning_pos < array_length: result_list.append(self.read_argument(child_type)) return result_list @@ -262,24 +269,24 @@ def read_argument(self, type_: SignatureType) -> Any: reader, ctype, size, struct = self.readers[type_.token] if reader: # complex type return reader(self, type_) - self.offset += size + (-self.offset & (size - 1)) # align + self.pos += size + (-self.pos & (size - 1)) # align if struct: # struct only set if we cannot cast - return struct.unpack_from(self.view, self.offset - size)[0] - return self.view[self.offset - size : self.offset].cast(ctype)[0] + return struct(self.view, self.pos - size)[0] + return self.view[self.pos - size : self.pos].cast(ctype)[0] def header_fields(self, header_length): """Header fields are always a(yv).""" - beginning_offset = self.offset + beginning_pos = self.pos headers = {} - while self.offset - beginning_offset < header_length: + while self.pos - beginning_pos < header_length: # Now read the y (byte) of struct (yv) - self.offset += (-self.offset & 7) + 1 # align 8 + 1 for 'y' byte - field_0 = self.view[self.offset - 1] + self.pos += (-self.pos & 7) + 1 # align 8 + 1 for 'y' byte + field_0 = self.view[self.pos - 1] # Now read the v (variant) of struct (yv) - signature_len = self.view[self.offset] # byte - o = self.offset + 1 - self.offset += signature_len + 2 # one for the byte, one for the '\0' + signature_len = self.view[self.pos] # byte + o = self.pos + 1 + self.pos += signature_len + 2 # one for the byte, one for the '\0' tree = SignatureTree._get(self.buf[o : o + signature_len].decode()) headers[HEADER_NAME_MAP[field_0]] = self.read_argument(tree.types[0]) return headers @@ -288,7 +295,7 @@ def _read_header(self): """Read the header of the message.""" # Signature is of the header is # BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT) - self.read_to_offset(HEADER_SIGNATURE_SIZE) + self.read_to_pos(HEADER_SIGNATURE_SIZE) buffer = self.buf endian = buffer[0] self.message_type = MESSAGE_TYPE_MAP[buffer[1]] @@ -310,23 +317,21 @@ def _read_header(self): self.msg_len = ( self.header_len + (-self.header_len & 7) + self.body_len ) # align 8 - self.readers = self._readers_by_type[ - ( - endian, - bool( - (IS_LITTLE_ENDIAN and endian == LITTLE_ENDIAN) - or (IS_BIG_ENDIAN and endian == BIG_ENDIAN) - ), - ) - ] + can_cast = bool( + (IS_LITTLE_ENDIAN and endian == LITTLE_ENDIAN) + or (IS_BIG_ENDIAN and endian == BIG_ENDIAN) + ) + self.readers = self._readers_by_type[(endian, can_cast)] + if not can_cast: + self._uint32_unpack = self.readers[UINT32_DBUS_TYPE][3] def _read_body(self): """Read the body of the message.""" - self.read_to_offset(HEADER_SIGNATURE_SIZE + self.msg_len) + self.read_to_pos(HEADER_SIGNATURE_SIZE + self.msg_len) self.view = memoryview(self.buf) - self.offset = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION + self.pos = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION header_fields = self.header_fields(self.header_len) - self.offset += -self.offset & 7 # align 8 + self.pos += -self.pos & 7 # align 8 tree = SignatureTree._get(header_fields.get(HeaderField.SIGNATURE.name, "")) self.message = Message( destination=header_fields.get(HEADER_DESTINATION), @@ -389,7 +394,7 @@ def unmarshall(self): } _ctype_by_endian: Dict[ - Tuple[int, bool], Dict[str, Tuple[None, str, int, Struct]] + Tuple[int, bool], Dict[str, Tuple[None, str, int, Callable]] ] = { endian_can_cast: { dbus_type: ( @@ -397,7 +402,9 @@ def unmarshall(self): *ctype_size, None if endian_can_cast[1] - else Struct(f"{UNPACK_SYMBOL[endian_can_cast[0]]}{ctype_size[0]}"), + else Struct( + f"{UNPACK_SYMBOL[endian_can_cast[0]]}{ctype_size[0]}" + ).unpack_from, ) for dbus_type, ctype_size in DBUS_TO_CTYPE.items() } diff --git a/src/dbus_fast/signature.py b/src/dbus_fast/signature.py index 16bd9e9b..46d58fd3 100644 --- a/src/dbus_fast/signature.py +++ b/src/dbus_fast/signature.py @@ -417,15 +417,12 @@ def __init__( value: Any, verify: bool = True, ): - signature_str = "" - signature_tree = None - signature_type = None - if type(signature) is SignatureTree: signature_tree = signature elif type(signature) is SignatureType: signature_type = signature signature_str = signature.signature + signature_tree = None elif type(signature) is str: signature_tree = SignatureTree._get(signature) else: