Skip to content

Commit

Permalink
Add support for PubSub with RESP3 parser (#2721)
Browse files Browse the repository at this point in the history
* add resp3 pubsub

* linters

* _set_info_logger func

* async pubsun

* docstring
  • Loading branch information
dvora-h authored Apr 24, 2023
1 parent 0db4eba commit a96a38a
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 30 deletions.
20 changes: 16 additions & 4 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import safe_str, str_if_bytes
from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes

PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
Expand Down Expand Up @@ -658,6 +658,7 @@ def __init__(
shard_hint: Optional[str] = None,
ignore_subscribe_messages: bool = False,
encoder=None,
push_handler_func: Optional[Callable] = None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
Expand All @@ -666,6 +667,7 @@ def __init__(
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
self.push_handler_func = push_handler_func
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
if self.encoder.decode_responses:
Expand All @@ -678,6 +680,8 @@ def __init__(
b"pong",
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
if self.push_handler_func is None:
_set_info_logger()
self.channels = {}
self.pending_unsubscribe_channels = set()
self.patterns = {}
Expand Down Expand Up @@ -757,6 +761,8 @@ async def connect(self):
self.connection.register_connect_callback(self.on_connect)
else:
await self.connection.connect()
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)

async def _disconnect_raise_connect(self, conn, error):
"""
Expand Down Expand Up @@ -797,7 +803,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
await conn.connect()

read_timeout = None if block else timeout
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
response = await self._execute(
conn, conn.read_response, timeout=read_timeout, push_request=True
)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down Expand Up @@ -927,15 +935,19 @@ def ping(self, message=None) -> Awaitable:
"""
Ping the Redis server
"""
message = "" if message is None else message
return self.execute_command("PING", message)
args = ["PING", message] if message is not None else ["PING"]
return self.execute_command(*args)

async def handle_message(self, response, ignore_subscribe_messages=False):
"""
Parses a pub/sub message. If the channel or pattern was subscribed to
with a message handler, the handler is invoked instead of a parsed
message being returned.
"""
if response is None:
return None
if isinstance(response, bytes):
response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
message_type = str_if_bytes(response[0])
if message_type == "pmessage":
message = {
Expand Down
16 changes: 15 additions & 1 deletion redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,15 +485,29 @@ async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
push_request: Optional[bool] = False,
):
"""Read the response from a previously sent command"""
read_timeout = timeout if timeout is not None else self.socket_timeout
try:
if read_timeout is not None:
if (
read_timeout is not None
and self.protocol == "3"
and not HIREDIS_AVAILABLE
):
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
elif read_timeout is not None:
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
elif self.protocol == "3" and not HIREDIS_AVAILABLE:
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
response = await self._parser.read_response(
disable_decoding=disable_decoding
Expand Down
16 changes: 12 additions & 4 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from redis.lock import Lock
from redis.retry import Retry
from redis.utils import safe_str, str_if_bytes
from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes

SYM_EMPTY = b""
EMPTY_RESPONSE = "EMPTY_RESPONSE"
Expand Down Expand Up @@ -1429,6 +1429,7 @@ def __init__(
shard_hint=None,
ignore_subscribe_messages=False,
encoder=None,
push_handler_func=None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
Expand All @@ -1438,13 +1439,16 @@ def __init__(
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
self.push_handler_func = push_handler_func
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
if self.encoder.decode_responses:
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
else:
self.health_check_response = [b"pong", self.health_check_response_b]
if self.push_handler_func is None:
_set_info_logger()
self.reset()

def __enter__(self):
Expand Down Expand Up @@ -1515,6 +1519,8 @@ def execute_command(self, *args):
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
Expand Down Expand Up @@ -1580,7 +1586,7 @@ def try_read():
return None
else:
conn.connect()
return conn.read_response()
return conn.read_response(push_request=True)

response = self._execute(conn, try_read)

Expand Down Expand Up @@ -1739,8 +1745,8 @@ def ping(self, message=None):
"""
Ping the Redis server
"""
message = "" if message is None else message
return self.execute_command("PING", message)
args = ["PING", message] if message is not None else ["PING"]
return self.execute_command(*args)

def handle_message(self, response, ignore_subscribe_messages=False):
"""
Expand All @@ -1750,6 +1756,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
"""
if response is None:
return None
if isinstance(response, bytes):
response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
message_type = str_if_bytes(response[0])
if message_type == "pmessage":
message = {
Expand Down
12 changes: 9 additions & 3 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,18 @@ def can_read(self, timeout=0):
self.disconnect()
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

def read_response(self, disable_decoding=False):
def read_response(self, disable_decoding=False, push_request=False):
"""Read the response from a previously sent command"""

host_error = self._host_error()

try:
response = self._parser.read_response(disable_decoding=disable_decoding)
if self.protocol == "3" and not HIREDIS_AVAILABLE:
response = self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
response = self._parser.read_response(disable_decoding=disable_decoding)
except socket.timeout:
self.disconnect()
raise TimeoutError(f"Timeout reading from {host_error}")
Expand Down Expand Up @@ -705,8 +710,9 @@ def _connect(self):
class UnixDomainSocketConnection(AbstractConnection):
"Manages UDS communication to and from a Redis server"

def __init__(self, path="", **kwargs):
def __init__(self, path="", socket_timeout=None, **kwargs):
self.path = path
self.socket_timeout = socket_timeout
super().__init__(**kwargs)

def repr_pieces(self):
Expand Down
81 changes: 74 additions & 7 deletions redis/parsers/resp3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from logging import getLogger
from typing import Any, Union

from ..exceptions import ConnectionError, InvalidResponse, ResponseError
Expand All @@ -9,18 +10,29 @@
class _RESP3Parser(_RESPBase):
"""RESP3 protocol implementation"""

def read_response(self, disable_decoding=False):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler_func = self.handle_push_response

def handle_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response

def read_response(self, disable_decoding=False, push_request=False):
pos = self._buffer.get_pos()
try:
result = self._read_response(disable_decoding=disable_decoding)
result = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
except BaseException:
self._buffer.rewind(pos)
raise
else:
self._buffer.purge()
return result

def _read_response(self, disable_decoding=False):
def _read_response(self, disable_decoding=False, push_request=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
Expand Down Expand Up @@ -77,31 +89,64 @@ def _read_response(self, disable_decoding=False):
response = {
self._read_response(
disable_decoding=disable_decoding
): self._read_response(disable_decoding=disable_decoding)
): self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
}
# push response
elif byte == b">":
response = [
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
]
res = self.push_handler_func(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func


class _AsyncRESP3Parser(_AsyncRESPBase):
async def read_response(self, disable_decoding: bool = False):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler_func = self.handle_push_response

def handle_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response

async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
):
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
response = await self._read_response(disable_decoding=disable_decoding)
response = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response

async def _read_response(
self, disable_decoding: bool = False
self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
Expand Down Expand Up @@ -166,9 +211,31 @@ async def _read_response(
)
for _ in range(int(response))
}
# push response
elif byte == b">":
response = [
(
await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
)
for _ in range(int(response))
]
res = self.push_handler_func(response)
if not push_request:
return await (
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
)
else:
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func
14 changes: 14 additions & 0 deletions redis/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from contextlib import contextmanager
from functools import wraps
from typing import Any, Dict, Mapping, Union
Expand Down Expand Up @@ -117,3 +118,16 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def _set_info_logger():
"""
Set up a logger that log info logs to stdout.
(This is used by the default push response handler)
"""
if "push_response" not in logging.root.manager.loggerDict.keys():
logger = logging.getLogger("push_response")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)
Loading

0 comments on commit a96a38a

Please sign in to comment.