diff --git a/src/dbus_fast/_private/unmarshaller.pxd b/src/dbus_fast/_private/unmarshaller.pxd index 4073cf0c..667f8d7e 100644 --- a/src/dbus_fast/_private/unmarshaller.pxd +++ b/src/dbus_fast/_private/unmarshaller.pxd @@ -106,6 +106,8 @@ cdef class Unmarshaller: cdef object _int16_unpack cdef object _uint16_unpack + cdef _reset(self) + cpdef reset(self) @cython.locals( @@ -173,6 +175,8 @@ cdef class Unmarshaller: ) cdef _read_body(self) + cdef _unmarshall(self) + cpdef unmarshall(self) @cython.locals( diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 06b0535b..d5695d2f 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -202,6 +202,13 @@ def __init__(self, stream: io.BufferedRWPair, sock: Optional[socket.socket] = No def reset(self) -> None: """Reset the unmarshaller to its initial state. + Call this before processing a new message. + """ + self._reset() + + def _reset(self) -> None: + """Reset the unmarshaller to its initial state. + Call this before processing a new message. """ self._unix_fds = [] @@ -596,6 +603,15 @@ def _read_body(self) -> None: def unmarshall(self) -> Optional[Message]: """Unmarshall the message. + The underlying read function will raise BlockingIOError if the + if there are not enough bytes in the buffer. This allows unmarshall + to be resumed when more data comes in over the wire. + """ + return self._unmarshall() + + def _unmarshall(self) -> Optional[Message]: + """Unmarshall the message. + The underlying read function will raise BlockingIOError if the if there are not enough bytes in the buffer. This allows unmarshall to be resumed when more data comes in over the wire. diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 3f4ab529..090d0cb3 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -170,7 +170,6 @@ def __init__( super().__init__(bus_address, bus_type, ProxyObject) self._negotiate_unix_fd = negotiate_unix_fd self._loop = asyncio.get_running_loop() - self._unmarshaller = self._create_unmarshaller() self._writer = _MessageWriter(self) @@ -201,7 +200,8 @@ async def connect(self) -> "MessageBus": self._loop.add_reader( self._fd, build_message_reader( - self._unmarshaller, + self._stream, + self._sock if self._negotiate_unix_fd else None, self._process_message, self._finalize, ), @@ -477,12 +477,6 @@ def disconnect(self) -> None: except Exception: logging.warning("could not close socket", exc_info=True) - def _create_unmarshaller(self) -> Unmarshaller: - sock = None - if self._negotiate_unix_fd: - sock = self._sock - return Unmarshaller(self._stream, sock) - def _finalize(self, err: Optional[Exception] = None) -> None: try: self._loop.remove_reader(self._fd) diff --git a/src/dbus_fast/aio/message_reader.pxd b/src/dbus_fast/aio/message_reader.pxd index a53a8e3e..e76627b2 100644 --- a/src/dbus_fast/aio/message_reader.pxd +++ b/src/dbus_fast/aio/message_reader.pxd @@ -1,3 +1,5 @@ """cdefs for message_reader.py""" import cython + +from .._private.unmarshaller cimport Unmarshaller diff --git a/src/dbus_fast/aio/message_reader.py b/src/dbus_fast/aio/message_reader.py index 4da1f08e..35d68834 100644 --- a/src/dbus_fast/aio/message_reader.py +++ b/src/dbus_fast/aio/message_reader.py @@ -1,4 +1,6 @@ +import io import logging +import socket import traceback from typing import Callable, Optional @@ -7,19 +9,19 @@ def build_message_reader( - unmarshaller: Unmarshaller, + stream: io.BufferedRWPair, + sock: Optional[socket.socket], process: Callable[[Message], None], finalize: Callable[[Optional[Exception]], None], ) -> None: """Build a callable that reads messages from the unmarshaller and passes them to the process function.""" - unmarshall = unmarshaller.unmarshall - reset = unmarshaller.reset + unmarshaller = Unmarshaller(stream, sock) def _message_reader() -> None: """Reads messages from the unmarshaller and passes them to the process function.""" try: while True: - message = unmarshall() + message = unmarshaller._unmarshall() if not message: return try: @@ -28,7 +30,7 @@ def _message_reader() -> None: logging.error( f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" ) - reset() + unmarshaller._reset() except Exception as e: finalize(e)