11import itertools
22from contextlib import closing
3- from typing import Any , Generator , List , Optional , Tuple , Union
3+ from typing import Any , Callable , Dict , Generator , List , Optional , Tuple , Union
44
55CRNL = b"\r \n "
66
@@ -34,11 +34,11 @@ def __init__(self, code: str, value: str) -> None:
3434 def __repr__ (self ) -> str :
3535 return f"ErrorString({ self .code !r} , { super ().__repr__ ()} )"
3636
37- def __str__ (self ):
37+ def __str__ (self ) -> str :
3838 return f"{ self .code } { super ().__str__ ()} "
3939
4040
41- class PushData (list ):
41+ class PushData (List [ Any ] ):
4242 """
4343 A special type of list indicating data from a push response
4444 """
@@ -47,7 +47,7 @@ def __repr__(self) -> str:
4747 return f"PushData({ super ().__repr__ ()} )"
4848
4949
50- class Attribute (dict ):
50+ class Attribute (Dict [ Any , Any ] ):
5151 """
5252 A special type of map indicating data from a attribute response
5353 """
@@ -62,7 +62,7 @@ class RespEncoder:
6262 """
6363
6464 def __init__ (
65- self , protocol : int = 2 , encoding : str = "utf-8" , errorhander = "strict"
65+ self , protocol : int = 2 , encoding : str = "utf-8" , errorhander : str = "strict"
6666 ) -> None :
6767 self .protocol = protocol
6868 self .encoding = encoding
@@ -248,7 +248,7 @@ def parse(
248248 rest += incoming
249249 string = self .decode_bytes (rest [: (count + 4 )])
250250 if string [3 ] != ":" :
251- raise ValueError (f"Expected colon after hint, got { bulkstr [3 ]} " )
251+ raise ValueError (f"Expected colon after hint, got { string [3 ]} " )
252252 hint = string [:3 ]
253253 string = string [4 : (count + 4 )]
254254 yield VerbatimStr (string , hint ), rest [expect :]
@@ -310,8 +310,8 @@ def parse(
310310 # we decode them automatically
311311 decoded = self .decode_bytes (arg )
312312 assert isinstance (decoded , str )
313- code , value = decoded .split (" " , 1 )
314- yield ErrorStr (code , value ), rest
313+ err , value = decoded .split (" " , 1 )
314+ yield ErrorStr (err , value ), rest
315315
316316 elif code == b"!" : # resp3 error
317317 count = int (arg )
@@ -323,8 +323,8 @@ def parse(
323323 bulkstr = rest [:count ]
324324 decoded = self .decode_bytes (bulkstr )
325325 assert isinstance (decoded , str )
326- code , value = decoded .split (" " , 1 )
327- yield ErrorStr (code , value ), rest [expect :]
326+ err , value = decoded .split (" " , 1 )
327+ yield ErrorStr (err , value ), rest [expect :]
328328
329329 else :
330330 raise ValueError (f"Unknown opcode '{ code .decode ()} '" )
@@ -427,26 +427,26 @@ class RespServer:
427427 Accepts RESP commands and returns RESP responses.
428428 """
429429
430- handlers = {}
430+ handlers : Dict [ str , Callable [..., Any ]] = {}
431431
432- def __init__ (self ):
432+ def __init__ (self ) -> None :
433433 self .protocol = 2
434434 self .server_ver = self .get_server_version ()
435- self .auth = []
435+ self .auth : List [ Any ] = []
436436 self .client_name = ""
437437
438438 # patchable methods for testing
439439
440- def get_server_version (self ):
440+ def get_server_version (self ) -> int :
441441 return 6
442442
443- def on_auth (self , auth ) :
443+ def on_auth (self , auth : List [ Any ]) -> None :
444444 pass
445445
446- def on_setname (self , name ) :
446+ def on_setname (self , name : str ) -> None :
447447 pass
448448
449- def on_protocol (self , proto ) :
449+ def on_protocol (self , proto : int ) -> None :
450450 pass
451451
452452 def command (self , cmd : Any ) -> bytes :
@@ -466,7 +466,7 @@ def _command(self, cmd: Any) -> Any:
466466
467467 return ErrorStr ("ERR" , "unknown command {cmd!r}" )
468468
469- def handle_auth (self , args ) :
469+ def handle_auth (self , args : List [ Any ]) -> Union [ str , ErrorStr ] :
470470 self .auth = args [:]
471471 self .on_auth (self .auth )
472472 expect = 2 if self .server_ver >= 6 else 1
@@ -476,21 +476,21 @@ def handle_auth(self, args):
476476
477477 handlers ["AUTH" ] = handle_auth
478478
479- def handle_client (self , args ) :
479+ def handle_client (self , args : List [ Any ]) -> Union [ str , ErrorStr ] :
480480 if args [0 ] == "SETNAME" :
481481 return self .handle_setname (args [1 :])
482482 return ErrorStr ("ERR" , "unknown subcommand or wrong number of arguments" )
483483
484484 handlers ["CLIENT" ] = handle_client
485485
486- def handle_setname (self , args ) :
486+ def handle_setname (self , args : List [ Any ]) -> Union [ str , ErrorStr ] :
487487 if len (args ) != 1 :
488488 return ErrorStr ("ERR" , "wrong number of arguments" )
489489 self .client_name = args [0 ]
490490 self .on_setname (self .client_name )
491491 return "OK"
492492
493- def handle_hello (self , args ) :
493+ def handle_hello (self , args : List [ Any ]) -> Union [ ErrorStr , Dict [ str , Any ]] :
494494 if self .server_ver < 6 :
495495 return ErrorStr ("ERR" , "unknown command 'HELLO'" )
496496 proto = self .protocol
@@ -507,14 +507,14 @@ def handle_hello(self, args):
507507 auth_args = args [:2 ]
508508 args = args [2 :]
509509 res = self .handle_auth (auth_args )
510- if res != "OK" :
510+ if isinstance ( res , ErrorStr ) :
511511 return res
512512 continue
513513 if cmd == "SETNAME" :
514514 setname_args = args [:1 ]
515515 args = args [1 :]
516516 res = self .handle_setname (setname_args )
517- if res != "OK" :
517+ if isinstance ( res , ErrorStr ) :
518518 return res
519519 continue
520520 return ErrorStr ("ERR" , "unknown subcommand or wrong number of arguments" )
0 commit comments