Skip to content

Commit 1c203e0

Browse files
authored
fix: Support TcpFakeServer for Python 3.9 (#411)
1 parent 9c0e8bb commit 1c203e0

File tree

3 files changed

+31
-41
lines changed

3 files changed

+31
-41
lines changed

fakeredis/__init__.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import sys
2-
31
from . import _typing
42
from ._connection import (
53
FakeRedis,
@@ -11,14 +9,7 @@
119
FakeRedis as FakeAsyncRedis,
1210
FakeConnection as FakeAsyncConnection,
1311
)
14-
15-
if sys.version_info >= (3, 11):
16-
from ._tcp_server import TcpFakeServer
17-
else:
18-
19-
class TcpFakeServer:
20-
def __init__(self, *args, **kwargs):
21-
raise NotImplementedError("TcpFakeServer is only available in Python 3.11+")
12+
from ._tcp_server import TcpFakeServer
2213

2314

2415
__version__ = _typing.lib_version

fakeredis/_tcp_server.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
22
from dataclasses import dataclass
3+
from io import BufferedIOBase
34
from itertools import count
45
from socketserver import ThreadingTCPServer, StreamRequestHandler
5-
from typing import BinaryIO, Dict, Tuple, Any
6+
from typing import Dict, Tuple, Any, Union
67

78
from fakeredis import FakeRedis
89
from fakeredis import FakeServer
@@ -13,7 +14,7 @@
1314
# logging.basicConfig(level=logging.DEBUG)
1415

1516

16-
def to_bytes(value) -> bytes:
17+
def to_bytes(value: Any) -> bytes:
1718
if isinstance(value, bytes):
1819
return value
1920
return str(value).encode()
@@ -27,37 +28,36 @@ class Client:
2728

2829
@dataclass
2930
class Reader:
30-
reader: BinaryIO
31+
reader: BufferedIOBase
3132

3233
def load(self) -> Any:
3334
line = self.reader.readline().strip()
34-
match line[0:1], line[1:]:
35-
case b"*", length:
36-
length = int(length)
37-
array = [None] * length
38-
for i in range(length):
39-
array[i] = self.load()
40-
return array
41-
case b"$", length:
42-
bulk_string = self.reader.read(int(length) + 2).strip()
43-
if len(bulk_string) != int(length):
44-
raise ValueError()
45-
return bulk_string
46-
case b":", value:
47-
return int(value)
48-
case b"+", value:
49-
return value
50-
case b"-", value:
51-
return Exception(value)
52-
case _:
53-
return None
35+
prefix, rest = line[0:1], line[1:]
36+
if prefix == b"*":
37+
length = int(rest)
38+
array = [None] * length
39+
for i in range(length):
40+
array[i] = self.load()
41+
return array
42+
if prefix == b"$":
43+
bulk_string = self.reader.read(int(rest) + 2).strip()
44+
if len(bulk_string) != int(rest):
45+
raise ValueError()
46+
return bulk_string
47+
if prefix == b":":
48+
return int(rest)
49+
if prefix == b"+":
50+
return rest
51+
if prefix == b"-":
52+
return Exception(rest)
53+
return None
5454

5555

5656
@dataclass
5757
class Writer:
58-
writer: BinaryIO
58+
writer: BufferedIOBase
5959

60-
def dump(self, value: Any, dump_bulk=False) -> None:
60+
def dump(self, value: Any, dump_bulk: bool = False) -> None:
6161
if isinstance(value, int):
6262
self.writer.write(f":{value}\r\n".encode())
6363
elif isinstance(value, (str, bytes)):
@@ -77,6 +77,8 @@ def dump(self, value: Any, dump_bulk=False) -> None:
7777

7878

7979
class TCPFakeRequestHandler(StreamRequestHandler):
80+
server: "TcpFakeServer" # type: ignore
81+
8082
def setup(self) -> None:
8183
super().setup()
8284
if self.client_address in self.server.clients:
@@ -90,7 +92,7 @@ def setup(self) -> None:
9092
self.writer = Writer(self.wfile)
9193
self.server.clients[self.client_address] = self.current_client
9294

93-
def handle(self):
95+
def handle(self) -> None:
9496
LOGGER.debug(f"+++ {self.client_address[0]} connected")
9597
while True:
9698
try:
@@ -117,7 +119,7 @@ def finish(self) -> None:
117119
class TcpFakeServer(ThreadingTCPServer):
118120
def __init__(
119121
self,
120-
server_address: Tuple[str | bytes | bytearray, int],
122+
server_address: Tuple[Union[str, bytes, bytearray], int],
121123
bind_and_activate: bool = True,
122124
server_type: ServerType = "redis",
123125
server_version: VersionType = (8, 0),
@@ -126,7 +128,7 @@ def __init__(
126128
self.allow_reuse_address = True
127129
self.fake_server = FakeServer(server_type=server_type, version=server_version)
128130
self.client_ids = count(0)
129-
self.clients: Dict[int, FakeRedis] = dict()
131+
self.clients: Dict[int, Client] = dict()
130132

131133

132134
if __name__ == "__main__":

test/test_tcp_server/test_connectivity.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
import sys
21
import time
32
from threading import Thread
43

5-
import pytest
64
import redis
75

86
from fakeredis import TcpFakeServer
97

108

11-
@pytest.mark.skipif(sys.version_info < (3, 11), reason="TcpFakeServer is only available in Python 3.11+")
129
def test_tcp_server_started():
1310
server_address = ("127.0.0.1", 19000)
1411
server = TcpFakeServer(server_address)

0 commit comments

Comments
 (0)