diff --git a/CHANGES b/CHANGES index 8f2017218a..3865ed1067 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Add `address_remap` parameter to `RedisCluster` * Fix incorrect usage of once flag in async Sentinel * asyncio: Fix memory leak caused by hiredis (#2693) * Allow data to drain from async PythonParser when reading during a disconnect() diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e31ec3491e..8518547518 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,15 +2,15 @@ ## Introduction -First off, thank you for considering contributing to redis-py. We value -community contributions! +We appreciate your interest in considering contributing to redis-py. +Community contributions mean a lot to us. -## Contributions We Need +## Contributions we need -You may already know what you want to contribute \-- a fix for a bug you +You may already know how you'd like to contribute, whether it's a fix for a bug you encountered, or a new feature your team wants to use. -If you don't know what to contribute, keep an open mind! Improving +If you don't know where to start, consider improving documentation, bug triaging, and writing tutorials are all examples of helpful contributions that mean less work for you. @@ -166,19 +166,19 @@ When filing an issue, make sure to answer these five questions: 4. What did you expect to see? 5. What did you see instead? -## How to Suggest a Feature or Enhancement +## Suggest a feature or enhancement If you'd like to contribute a new feature, make sure you check our issue list to see if someone has already proposed it. Work may already -be under way on the feature you want -- or we may have rejected a +be underway on the feature you want or we may have rejected a feature like it already. If you don't see anything, open a new issue that describes the feature you would like and how it should work. -## Code Review Process +## Code review process -The core team looks at Pull Requests on a regular basis. We will give -feedback as as soon as possible. After feedback, we expect a response +The core team regularly looks at pull requests. We will provide +feedback as as soon as possible. After receiving our feedback, please respond within two weeks. After that time, we may close your PR if it isn't showing any activity. diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index a4a9561cf1..eb5f4db061 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -5,12 +5,14 @@ import warnings from typing import ( Any, + Callable, Deque, Dict, Generator, List, Mapping, Optional, + Tuple, Type, TypeVar, Union, @@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand maximum number of connections are already created, a :class:`~.MaxConnectionsError` is raised. This error may be retried as defined by :attr:`connection_error_retry_attempts` + :param address_remap: + | An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. | Rest of the arguments will be passed to the :class:`~redis.asyncio.connection.Connection` instances when created @@ -250,6 +258,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -337,7 +346,12 @@ def __init__( if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs) + self.nodes_manager = NodesManager( + startup_nodes, + require_full_coverage, + kwargs, + address_remap=address_remap, + ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas self.reinitialize_steps = reinitialize_steps @@ -1059,6 +1073,7 @@ class NodesManager: "require_full_coverage", "slots_cache", "startup_nodes", + "address_remap", ) def __init__( @@ -1066,10 +1081,12 @@ def __init__( startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage self.connection_kwargs = connection_kwargs + self.address_remap = address_remap self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1228,6 +1245,7 @@ async def initialize(self) -> None: if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: @@ -1246,6 +1264,7 @@ async def initialize(self) -> None: for replica_node in replica_nodes: host = replica_node[0] port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = tmp_nodes_cache.get( get_node_name(host, port) @@ -1319,6 +1338,16 @@ async def close(self, attr: str = "nodes_cache") -> None: ) ) + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ diff --git a/redis/cluster.py b/redis/cluster.py index 5e6e7da546..3ecc2dab56 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -466,6 +466,7 @@ def __init__( read_from_replicas: bool = False, dynamic_startup_nodes: bool = True, url: Optional[str] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): """ @@ -514,6 +515,12 @@ def __init__( reinitialize_steps to 1. To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. + :param address_remap: + An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. :**kwargs: Extra arguments that will be sent into Redis instance when created @@ -594,6 +601,7 @@ def __init__( from_url=from_url, require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, + address_remap=address_remap, **kwargs, ) @@ -1269,6 +1277,7 @@ def __init__( lock=None, dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): self.nodes_cache = {} @@ -1280,6 +1289,7 @@ def __init__( self._require_full_coverage = require_full_coverage self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class + self.address_remap = address_remap self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1502,6 +1512,7 @@ def initialize(self): if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = self._get_or_create_cluster_node( host, port, PRIMARY, tmp_nodes_cache @@ -1518,6 +1529,7 @@ def initialize(self): for replica_node in replica_nodes: host = str_if_bytes(replica_node[0]) port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = self._get_or_create_cluster_node( host, port, REPLICA, tmp_nodes_cache @@ -1591,6 +1603,16 @@ def reset(self): # The read_load_balancer is None, do nothing pass + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPubSub(PubSub): """ diff --git a/tests/ssl_utils.py b/tests/ssl_utils.py new file mode 100644 index 0000000000..50937638a7 --- /dev/null +++ b/tests/ssl_utils.py @@ -0,0 +1,14 @@ +import os + + +def get_ssl_filename(name): + root = os.path.join(os.path.dirname(__file__), "..") + cert_dir = os.path.abspath(os.path.join(root, "docker", "stunnel", "keys")) + if not os.path.isdir(cert_dir): # github actions package validation case + cert_dir = os.path.abspath( + os.path.join(root, "..", "docker", "stunnel", "keys") + ) + if not os.path.isdir(cert_dir): + raise IOError(f"No SSL certificates found. They should be in {cert_dir}") + + return os.path.join(cert_dir, name) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 13e5e26ae3..afe9e23b78 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,7 +1,6 @@ import asyncio import binascii import datetime -import os import warnings from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union from urllib.parse import urlparse @@ -11,7 +10,7 @@ from _pytest.fixtures import FixtureRequest from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster -from redis.asyncio.connection import Connection, SSLConnection +from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff @@ -36,6 +35,7 @@ skip_unless_arch_bits, ) +from ..ssl_utils import get_ssl_filename from .compat import mock pytestmark = pytest.mark.onlycluster @@ -49,6 +49,71 @@ ] +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.send_event = asyncio.Event() + self.server = None + self.task = None + self.n_connections = 0 + + async def start(self): + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + try: + self.n_connections += 1 + pipe1 = asyncio.create_task(self.pipe(reader, redis_writer)) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer)) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() + + async def aclose(self): + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + await self.server.wait_closed() + + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + while True: + data = await reader.read(1000) + if not data: + break + writer.write(data) + await writer.drain() + + +@pytest.fixture +def redis_addr(request): + redis_url = request.config.getoption("--redis-url") + scheme, netloc = urlparse(redis_url)[:2] + assert scheme == "redis" + if ":" in netloc: + host, port = netloc.split(":") + return host, int(port) + else: + return netloc, 6379 + + @pytest_asyncio.fixture() async def slowlog(r: RedisCluster) -> None: """ @@ -809,6 +874,49 @@ async def test_default_node_is_replaced_after_exception(self, r): # Rollback to the old default node r.replace_default_node(curr_default_node) + async def test_address_remap(self, create_redis, redis_addr): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + ports = [redis_addr[1] + i for i in range(n)] + + def address_remap(address): + # remap first three nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) + for port in ports + ] + await asyncio.gather(*[p.start() for p in proxies]) + try: + # create cluster: + r = await create_redis( + cls=RedisCluster, flushdb=False, address_remap=address_remap + ) + try: + assert await r.ping() is True + assert await r.set("byte_string", b"giraffe") + assert await r.get("byte_string") == b"giraffe" + finally: + await r.close() + finally: + await asyncio.gather(*[p.aclose() for p in proxies]) + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + class TestClusterRedisCommands: """ @@ -2641,17 +2749,8 @@ class TestSSL: appropriate port. """ - ROOT = os.path.join(os.path.dirname(__file__), "../..") - CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) - if not os.path.isdir(CERT_DIR): # github actions package validation case - CERT_DIR = os.path.abspath( - os.path.join(ROOT, "..", "docker", "stunnel", "keys") - ) - if not os.path.isdir(CERT_DIR): - raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") - - SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") - SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + SERVER_CERT = get_ssl_filename("server-cert.pem") + SERVER_KEY = get_ssl_filename("server-key.pem") @pytest_asyncio.fixture() def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]: diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py new file mode 100644 index 0000000000..8e3209fdc6 --- /dev/null +++ b/tests/test_asyncio/test_connect.py @@ -0,0 +1,145 @@ +import asyncio +import logging +import re +import socket +import ssl + +import pytest + +from redis.asyncio.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, +) + +from ..ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +async def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) + await _assert_connect(conn, tcp_address) + + +async def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection( + path=path, client_name=_CLIENT_NAME, socket_timeout=10 + ) + await _assert_connect(conn, path) + + +@pytest.mark.ssl +async def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ) + await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +async def _assert_connect(conn, server_address, certfile=None, keyfile=None): + stop_event = asyncio.Event() + finished = asyncio.Event() + + async def _handler(reader, writer): + try: + return await _redis_request_handler(reader, writer, stop_event) + finally: + finished.set() + + if isinstance(server_address, str): + server = await asyncio.start_unix_server(_handler, path=server_address) + elif certfile: + host, port = server_address + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + context.load_cert_chain(certfile=certfile, keyfile=keyfile) + server = await asyncio.start_server(_handler, host=host, port=port, ssl=context) + else: + host, port = server_address + server = await asyncio.start_server(_handler, host=host, port=port) + + async with server as aserver: + await aserver.start_serving() + try: + await conn.connect() + await conn.disconnect() + finally: + stop_event.set() + aserver.close() + await aserver.wait_closed() + await finished.wait() + + +async def _redis_request_handler(reader, writer, stop_event): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while not stop_event.is_set() or buffer: + _logger.info(str(stop_event.is_set())) + try: + buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) + except TimeoutError: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment: %s", fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command %s", command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response from %s", resp) + writer.write(resp) + await writer.drain() + command = None + _logger.info("Exit handler") diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 58f9b77d7d..1f037c9edf 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,9 +1,14 @@ import binascii import datetime +import select +import socket +import socketserver +import threading import warnings from queue import LifoQueue, Queue from time import sleep from unittest.mock import DEFAULT, Mock, call, patch +from urllib.parse import urlparse import pytest @@ -53,6 +58,85 @@ ] +class ProxyRequestHandler(socketserver.BaseRequestHandler): + def recv(self, sock): + """A recv with a timeout""" + r = select.select([sock], [], [], 0.01) + if not r[0]: + return None + return sock.recv(1000) + + def handle(self): + self.server.proxy.n_connections += 1 + conn = socket.create_connection(self.server.proxy.redis_addr) + stop = False + + def from_server(): + # read from server and pass to client + while not stop: + data = self.recv(conn) + if data is None: + continue + if not data: + self.request.shutdown(socket.SHUT_WR) + return + self.request.sendall(data) + + thread = threading.Thread(target=from_server) + thread.start() + try: + while True: + # read from client and send to server + data = self.request.recv(1000) + if not data: + return + conn.sendall(data) + finally: + conn.shutdown(socket.SHUT_WR) + stop = True # for safety + thread.join() + conn.close() + + +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.server = socketserver.ThreadingTCPServer(self.addr, ProxyRequestHandler) + self.server.proxy = self + self.server.socket_reuse_address = True + self.thread = None + self.n_connections = 0 + + def start(self): + # test that we can connect to redis + s = socket.create_connection(self.redis_addr, timeout=2) + s.close() + # Start a thread with the server -- that thread will then start one + # more thread for each request + self.thread = threading.Thread(target=self.server.serve_forever) + # Exit the server thread when the main thread terminates + self.thread.daemon = True + self.thread.start() + + def close(self): + self.server.shutdown() + + +@pytest.fixture +def redis_addr(request): + redis_url = request.config.getoption("--redis-url") + scheme, netloc = urlparse(redis_url)[:2] + assert scheme == "redis" + if ":" in netloc: + host, port = netloc.split(":") + return host, int(port) + else: + return netloc, 6379 + + @pytest.fixture() def slowlog(request, r): """ @@ -823,6 +907,51 @@ def raise_connection_error(): assert "myself" not in nodes.get(curr_default_node.name).get("flags") assert r.get_default_node() != curr_default_node + def test_address_remap(self, request, redis_addr): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + ports = [redis_addr[1] + i for i in range(n)] + + def address_remap(address): + # remap first three nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) + for port in ports + ] + for p in proxies: + p.start() + try: + # create cluster: + r = _get_client( + RedisCluster, request, flushdb=False, address_remap=address_remap + ) + try: + assert r.ping() is True + assert r.set("byte_string", b"giraffe") + assert r.get("byte_string") == b"giraffe" + finally: + r.close() + finally: + for p in proxies: + p.close() + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + @pytest.mark.onlycluster class TestClusterRedisCommands: diff --git a/tests/test_connect.py b/tests/test_connect.py new file mode 100644 index 0000000000..b4ec7020e1 --- /dev/null +++ b/tests/test_connect.py @@ -0,0 +1,185 @@ +import logging +import re +import socket +import socketserver +import ssl +import threading + +import pytest + +from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection + +from .ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) + _assert_connect(conn, tcp_address) + + +def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10) + _assert_connect(conn, path) + + +@pytest.mark.ssl +def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ) + _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +def _assert_connect(conn, server_address, certfile=None, keyfile=None): + if isinstance(server_address, str): + server = _RedisUDSServer(server_address, _RedisRequestHandler) + else: + server = _RedisTCPServer( + server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile + ) + with server as aserver: + t = threading.Thread(target=aserver.serve_forever) + t.start() + try: + aserver.wait_online() + conn.connect() + conn.disconnect() + finally: + aserver.stop() + t.join(timeout=5) + + +class _RedisTCPServer(socketserver.TCPServer): + def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None: + self._ready_event = threading.Event() + self._stop_requested = False + self._certfile = certfile + self._keyfile = keyfile + super().__init__(*args, **kw) + + def service_actions(self): + self._ready_event.set() + + def wait_online(self): + self._ready_event.wait() + + def stop(self): + self._stop_requested = True + self.shutdown() + + def is_serving(self): + return not self._stop_requested + + def get_request(self): + if self._certfile is None: + return super().get_request() + newsocket, fromaddr = self.socket.accept() + connstream = ssl.wrap_socket( + newsocket, + server_side=True, + certfile=self._certfile, + keyfile=self._keyfile, + ssl_version=ssl.PROTOCOL_TLSv1_2, + ) + return connstream, fromaddr + + +class _RedisUDSServer(socketserver.UnixStreamServer): + def __init__(self, *args, **kw) -> None: + self._ready_event = threading.Event() + self._stop_requested = False + super().__init__(*args, **kw) + + def service_actions(self): + self._ready_event.set() + + def wait_online(self): + self._ready_event.wait() + + def stop(self): + self._stop_requested = True + self.shutdown() + + def is_serving(self): + return not self._stop_requested + + +class _RedisRequestHandler(socketserver.StreamRequestHandler): + def setup(self): + _logger.info("%s connected", self.client_address) + + def finish(self): + _logger.info("%s disconnected", self.client_address) + + def handle(self): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while self.server.is_serving() or buffer: + try: + buffer += self.request.recv(1024) + except socket.timeout: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment: %s", fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command %s", command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response %s", resp) + self.request.sendall(resp) + command = None + _logger.info("Exit handler") diff --git a/tests/test_ssl.py b/tests/test_ssl.py index ed38a3166b..c1a981d310 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1,4 +1,3 @@ -import os import socket import ssl from urllib.parse import urlparse @@ -9,6 +8,7 @@ from redis.exceptions import ConnectionError, RedisError from .conftest import skip_if_cryptography, skip_if_nocryptography +from .ssl_utils import get_ssl_filename @pytest.mark.ssl @@ -19,17 +19,8 @@ class TestSSL: and connecting to the appropriate port. """ - ROOT = os.path.join(os.path.dirname(__file__), "..") - CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) - if not os.path.isdir(CERT_DIR): # github actions package validation case - CERT_DIR = os.path.abspath( - os.path.join(ROOT, "..", "docker", "stunnel", "keys") - ) - if not os.path.isdir(CERT_DIR): - raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") - - SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") - SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + SERVER_CERT = get_ssl_filename("server-cert.pem") + SERVER_KEY = get_ssl_filename("server-key.pem") def test_ssl_with_invalid_cert(self, request): ssl_url = request.config.option.redis_ssl_url