diff --git a/tests/ssl_utils.py b/tests/ssl_utils.py new file mode 100644 index 0000000000..50937638a7 --- /dev/null +++ b/tests/ssl_utils.py @@ -0,0 +1,14 @@ +import os + + +def get_ssl_filename(name): + root = os.path.join(os.path.dirname(__file__), "..") + cert_dir = os.path.abspath(os.path.join(root, "docker", "stunnel", "keys")) + if not os.path.isdir(cert_dir): # github actions package validation case + cert_dir = os.path.abspath( + os.path.join(root, "..", "docker", "stunnel", "keys") + ) + if not os.path.isdir(cert_dir): + raise IOError(f"No SSL certificates found. They should be in {cert_dir}") + + return os.path.join(cert_dir, name) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 17aa879b0f..c41d4a2168 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,7 +1,6 @@ import asyncio import binascii import datetime -import os import warnings from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union from urllib.parse import urlparse @@ -36,6 +35,7 @@ skip_unless_arch_bits, ) +from ..ssl_utils import get_ssl_filename from .compat import mock pytestmark = pytest.mark.onlycluster @@ -2744,17 +2744,8 @@ class TestSSL: appropriate port. """ - ROOT = os.path.join(os.path.dirname(__file__), "../..") - CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) - if not os.path.isdir(CERT_DIR): # github actions package validation case - CERT_DIR = os.path.abspath( - os.path.join(ROOT, "..", "docker", "stunnel", "keys") - ) - if not os.path.isdir(CERT_DIR): - raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") - - SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") - SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + SERVER_CERT = get_ssl_filename("server-cert.pem") + SERVER_KEY = get_ssl_filename("server-key.pem") @pytest_asyncio.fixture() def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]: diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py new file mode 100644 index 0000000000..8e3209fdc6 --- /dev/null +++ b/tests/test_asyncio/test_connect.py @@ -0,0 +1,145 @@ +import asyncio +import logging +import re +import socket +import ssl + +import pytest + +from redis.asyncio.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, +) + +from ..ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +async def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) + await _assert_connect(conn, tcp_address) + + +async def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection( + path=path, client_name=_CLIENT_NAME, socket_timeout=10 + ) + await _assert_connect(conn, path) + + +@pytest.mark.ssl +async def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ) + await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +async def _assert_connect(conn, server_address, certfile=None, keyfile=None): + stop_event = asyncio.Event() + finished = asyncio.Event() + + async def _handler(reader, writer): + try: + return await _redis_request_handler(reader, writer, stop_event) + finally: + finished.set() + + if isinstance(server_address, str): + server = await asyncio.start_unix_server(_handler, path=server_address) + elif certfile: + host, port = server_address + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + context.load_cert_chain(certfile=certfile, keyfile=keyfile) + server = await asyncio.start_server(_handler, host=host, port=port, ssl=context) + else: + host, port = server_address + server = await asyncio.start_server(_handler, host=host, port=port) + + async with server as aserver: + await aserver.start_serving() + try: + await conn.connect() + await conn.disconnect() + finally: + stop_event.set() + aserver.close() + await aserver.wait_closed() + await finished.wait() + + +async def _redis_request_handler(reader, writer, stop_event): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while not stop_event.is_set() or buffer: + _logger.info(str(stop_event.is_set())) + try: + buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) + except TimeoutError: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment: %s", fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command %s", command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response from %s", resp) + writer.write(resp) + await writer.drain() + command = None + _logger.info("Exit handler") diff --git a/tests/test_connect.py b/tests/test_connect.py new file mode 100644 index 0000000000..b4ec7020e1 --- /dev/null +++ b/tests/test_connect.py @@ -0,0 +1,185 @@ +import logging +import re +import socket +import socketserver +import ssl +import threading + +import pytest + +from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection + +from .ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) + _assert_connect(conn, tcp_address) + + +def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10) + _assert_connect(conn, path) + + +@pytest.mark.ssl +def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ) + _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +def _assert_connect(conn, server_address, certfile=None, keyfile=None): + if isinstance(server_address, str): + server = _RedisUDSServer(server_address, _RedisRequestHandler) + else: + server = _RedisTCPServer( + server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile + ) + with server as aserver: + t = threading.Thread(target=aserver.serve_forever) + t.start() + try: + aserver.wait_online() + conn.connect() + conn.disconnect() + finally: + aserver.stop() + t.join(timeout=5) + + +class _RedisTCPServer(socketserver.TCPServer): + def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None: + self._ready_event = threading.Event() + self._stop_requested = False + self._certfile = certfile + self._keyfile = keyfile + super().__init__(*args, **kw) + + def service_actions(self): + self._ready_event.set() + + def wait_online(self): + self._ready_event.wait() + + def stop(self): + self._stop_requested = True + self.shutdown() + + def is_serving(self): + return not self._stop_requested + + def get_request(self): + if self._certfile is None: + return super().get_request() + newsocket, fromaddr = self.socket.accept() + connstream = ssl.wrap_socket( + newsocket, + server_side=True, + certfile=self._certfile, + keyfile=self._keyfile, + ssl_version=ssl.PROTOCOL_TLSv1_2, + ) + return connstream, fromaddr + + +class _RedisUDSServer(socketserver.UnixStreamServer): + def __init__(self, *args, **kw) -> None: + self._ready_event = threading.Event() + self._stop_requested = False + super().__init__(*args, **kw) + + def service_actions(self): + self._ready_event.set() + + def wait_online(self): + self._ready_event.wait() + + def stop(self): + self._stop_requested = True + self.shutdown() + + def is_serving(self): + return not self._stop_requested + + +class _RedisRequestHandler(socketserver.StreamRequestHandler): + def setup(self): + _logger.info("%s connected", self.client_address) + + def finish(self): + _logger.info("%s disconnected", self.client_address) + + def handle(self): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while self.server.is_serving() or buffer: + try: + buffer += self.request.recv(1024) + except socket.timeout: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment: %s", fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command %s", command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response %s", resp) + self.request.sendall(resp) + command = None + _logger.info("Exit handler") diff --git a/tests/test_ssl.py b/tests/test_ssl.py index ed38a3166b..c1a981d310 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1,4 +1,3 @@ -import os import socket import ssl from urllib.parse import urlparse @@ -9,6 +8,7 @@ from redis.exceptions import ConnectionError, RedisError from .conftest import skip_if_cryptography, skip_if_nocryptography +from .ssl_utils import get_ssl_filename @pytest.mark.ssl @@ -19,17 +19,8 @@ class TestSSL: and connecting to the appropriate port. """ - ROOT = os.path.join(os.path.dirname(__file__), "..") - CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) - if not os.path.isdir(CERT_DIR): # github actions package validation case - CERT_DIR = os.path.abspath( - os.path.join(ROOT, "..", "docker", "stunnel", "keys") - ) - if not os.path.isdir(CERT_DIR): - raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") - - SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") - SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + SERVER_CERT = get_ssl_filename("server-cert.pem") + SERVER_KEY = get_ssl_filename("server-key.pem") def test_ssl_with_invalid_cert(self, request): ssl_url = request.config.option.redis_ssl_url