Skip to content

Commit

Permalink
feat: refactor message_reader to avoid python wrappers (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Nov 4, 2022
1 parent c65bcd7 commit b81de45
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 13 deletions.
4 changes: 4 additions & 0 deletions src/dbus_fast/_private/unmarshaller.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ cdef class Unmarshaller:
cdef object _int16_unpack
cdef object _uint16_unpack

cdef _reset(self)

cpdef reset(self)

@cython.locals(
Expand Down Expand Up @@ -173,6 +175,8 @@ cdef class Unmarshaller:
)
cdef _read_body(self)

cdef _unmarshall(self)

cpdef unmarshall(self)

@cython.locals(
Expand Down
16 changes: 16 additions & 0 deletions src/dbus_fast/_private/unmarshaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ def __init__(self, stream: io.BufferedRWPair, sock: Optional[socket.socket] = No
def reset(self) -> None:
"""Reset the unmarshaller to its initial state.
Call this before processing a new message.
"""
self._reset()

def _reset(self) -> None:
"""Reset the unmarshaller to its initial state.
Call this before processing a new message.
"""
self._unix_fds = []
Expand Down Expand Up @@ -596,6 +603,15 @@ def _read_body(self) -> None:
def unmarshall(self) -> Optional[Message]:
"""Unmarshall the message.
The underlying read function will raise BlockingIOError if the
if there are not enough bytes in the buffer. This allows unmarshall
to be resumed when more data comes in over the wire.
"""
return self._unmarshall()

def _unmarshall(self) -> Optional[Message]:
"""Unmarshall the message.
The underlying read function will raise BlockingIOError if the
if there are not enough bytes in the buffer. This allows unmarshall
to be resumed when more data comes in over the wire.
Expand Down
10 changes: 2 additions & 8 deletions src/dbus_fast/aio/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def __init__(
super().__init__(bus_address, bus_type, ProxyObject)
self._negotiate_unix_fd = negotiate_unix_fd
self._loop = asyncio.get_running_loop()
self._unmarshaller = self._create_unmarshaller()

self._writer = _MessageWriter(self)

Expand Down Expand Up @@ -201,7 +200,8 @@ async def connect(self) -> "MessageBus":
self._loop.add_reader(
self._fd,
build_message_reader(
self._unmarshaller,
self._stream,
self._sock if self._negotiate_unix_fd else None,
self._process_message,
self._finalize,
),
Expand Down Expand Up @@ -477,12 +477,6 @@ def disconnect(self) -> None:
except Exception:
logging.warning("could not close socket", exc_info=True)

def _create_unmarshaller(self) -> Unmarshaller:
sock = None
if self._negotiate_unix_fd:
sock = self._sock
return Unmarshaller(self._stream, sock)

def _finalize(self, err: Optional[Exception] = None) -> None:
try:
self._loop.remove_reader(self._fd)
Expand Down
2 changes: 2 additions & 0 deletions src/dbus_fast/aio/message_reader.pxd
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""cdefs for message_reader.py"""

import cython

from .._private.unmarshaller cimport Unmarshaller
12 changes: 7 additions & 5 deletions src/dbus_fast/aio/message_reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import logging
import socket
import traceback
from typing import Callable, Optional

Expand All @@ -7,19 +9,19 @@


def build_message_reader(
unmarshaller: Unmarshaller,
stream: io.BufferedRWPair,
sock: Optional[socket.socket],
process: Callable[[Message], None],
finalize: Callable[[Optional[Exception]], None],
) -> None:
"""Build a callable that reads messages from the unmarshaller and passes them to the process function."""
unmarshall = unmarshaller.unmarshall
reset = unmarshaller.reset
unmarshaller = Unmarshaller(stream, sock)

def _message_reader() -> None:
"""Reads messages from the unmarshaller and passes them to the process function."""
try:
while True:
message = unmarshall()
message = unmarshaller._unmarshall()
if not message:
return
try:
Expand All @@ -28,7 +30,7 @@ def _message_reader() -> None:
logging.error(
f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}"
)
reset()
unmarshaller._reset()
except Exception as e:
finalize(e)

Expand Down

0 comments on commit b81de45

Please sign in to comment.