11import logging
22from dataclasses import dataclass
3+ from io import BufferedIOBase
34from itertools import count
45from socketserver import ThreadingTCPServer , StreamRequestHandler
5- from typing import BinaryIO , Dict , Tuple , Any
6+ from typing import Dict , Tuple , Any , Union
67
78from fakeredis import FakeRedis
89from fakeredis import FakeServer
1314# logging.basicConfig(level=logging.DEBUG)
1415
1516
16- def to_bytes (value ) -> bytes :
17+ def to_bytes (value : Any ) -> bytes :
1718 if isinstance (value , bytes ):
1819 return value
1920 return str (value ).encode ()
@@ -27,37 +28,36 @@ class Client:
2728
2829@dataclass
2930class Reader :
30- reader : BinaryIO
31+ reader : BufferedIOBase
3132
3233 def load (self ) -> Any :
3334 line = self .reader .readline ().strip ()
34- match line [0 :1 ], line [1 :]:
35- case b"*" , length :
36- length = int (length )
37- array = [None ] * length
38- for i in range (length ):
39- array [i ] = self .load ()
40- return array
41- case b"$" , length :
42- bulk_string = self .reader .read (int (length ) + 2 ).strip ()
43- if len (bulk_string ) != int (length ):
44- raise ValueError ()
45- return bulk_string
46- case b":" , value :
47- return int (value )
48- case b"+" , value :
49- return value
50- case b"-" , value :
51- return Exception (value )
52- case _:
53- return None
35+ prefix , rest = line [0 :1 ], line [1 :]
36+ if prefix == b"*" :
37+ length = int (rest )
38+ array = [None ] * length
39+ for i in range (length ):
40+ array [i ] = self .load ()
41+ return array
42+ if prefix == b"$" :
43+ bulk_string = self .reader .read (int (rest ) + 2 ).strip ()
44+ if len (bulk_string ) != int (rest ):
45+ raise ValueError ()
46+ return bulk_string
47+ if prefix == b":" :
48+ return int (rest )
49+ if prefix == b"+" :
50+ return rest
51+ if prefix == b"-" :
52+ return Exception (rest )
53+ return None
5454
5555
5656@dataclass
5757class Writer :
58- writer : BinaryIO
58+ writer : BufferedIOBase
5959
60- def dump (self , value : Any , dump_bulk = False ) -> None :
60+ def dump (self , value : Any , dump_bulk : bool = False ) -> None :
6161 if isinstance (value , int ):
6262 self .writer .write (f":{ value } \r \n " .encode ())
6363 elif isinstance (value , (str , bytes )):
@@ -77,6 +77,8 @@ def dump(self, value: Any, dump_bulk=False) -> None:
7777
7878
7979class TCPFakeRequestHandler (StreamRequestHandler ):
80+ server : "TcpFakeServer" # type: ignore
81+
8082 def setup (self ) -> None :
8183 super ().setup ()
8284 if self .client_address in self .server .clients :
@@ -90,7 +92,7 @@ def setup(self) -> None:
9092 self .writer = Writer (self .wfile )
9193 self .server .clients [self .client_address ] = self .current_client
9294
93- def handle (self ):
95+ def handle (self ) -> None :
9496 LOGGER .debug (f"+++ { self .client_address [0 ]} connected" )
9597 while True :
9698 try :
@@ -117,7 +119,7 @@ def finish(self) -> None:
117119class TcpFakeServer (ThreadingTCPServer ):
118120 def __init__ (
119121 self ,
120- server_address : Tuple [str | bytes | bytearray , int ],
122+ server_address : Tuple [Union [ str , bytes , bytearray ] , int ],
121123 bind_and_activate : bool = True ,
122124 server_type : ServerType = "redis" ,
123125 server_version : VersionType = (8 , 0 ),
@@ -126,7 +128,7 @@ def __init__(
126128 self .allow_reuse_address = True
127129 self .fake_server = FakeServer (server_type = server_type , version = server_version )
128130 self .client_ids = count (0 )
129- self .clients : Dict [int , FakeRedis ] = dict ()
131+ self .clients : Dict [int , Client ] = dict ()
130132
131133
132134if __name__ == "__main__" :
0 commit comments