@@ -826,7 +826,7 @@ async def on_connect(self) -> None:
826826 if str_if_bytes (await self .read_response ()) != "OK" :
827827 raise ConnectionError ("Invalid Database" )
828828
829- async def disconnect (self ) -> None :
829+ async def disconnect (self , nowait : bool = False ) -> None :
830830 """Disconnects from the Redis server"""
831831 try :
832832 async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -836,8 +836,10 @@ async def disconnect(self) -> None:
836836 try :
837837 if os .getpid () == self .pid :
838838 self ._writer .close () # type: ignore[union-attr]
839+ # wait for close to finish, except when handling errors and
840+ # forcecully disconnecting.
839841 # py3.6 doesn't have this method
840- if hasattr (self ._writer , "wait_closed" ):
842+ if not nowait and hasattr (self ._writer , "wait_closed" ):
841843 await self ._writer .wait_closed () # type: ignore[union-attr]
842844 except OSError :
843845 pass
@@ -892,10 +894,10 @@ async def send_packed_command(
892894 self ._writer .writelines (command )
893895 await self ._writer .drain ()
894896 except asyncio .TimeoutError :
895- await self .disconnect ()
897+ await self .disconnect (nowait = True )
896898 raise TimeoutError ("Timeout writing to socket" ) from None
897899 except OSError as e :
898- await self .disconnect ()
900+ await self .disconnect (nowait = True )
899901 if len (e .args ) == 1 :
900902 err_no , errmsg = "UNKNOWN" , e .args [0 ]
901903 else :
@@ -907,7 +909,7 @@ async def send_packed_command(
907909 except asyncio .CancelledError :
908910 raise # is Exception and not BaseException in 3.7 and earlier
909911 except Exception :
910- await self .disconnect ()
912+ await self .disconnect (nowait = True )
911913 raise
912914
913915 async def send_command (self , * args : Any , ** kwargs : Any ) -> None :
@@ -923,7 +925,7 @@ async def can_read(self, timeout: float = 0):
923925 try :
924926 return await self ._parser .can_read (timeout )
925927 except OSError as e :
926- await self .disconnect ()
928+ await self .disconnect (nowait = True )
927929 raise ConnectionError (
928930 f"Error while reading from { self .host } :{ self .port } : { e .args } "
929931 )
@@ -942,17 +944,17 @@ async def read_response(self, disable_decoding: bool = False):
942944 disable_decoding = disable_decoding
943945 )
944946 except asyncio .TimeoutError :
945- await self .disconnect ()
947+ await self .disconnect (nowait = True )
946948 raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
947949 except OSError as e :
948- await self .disconnect ()
950+ await self .disconnect (nowait = True )
949951 raise ConnectionError (
950952 f"Error while reading from { self .host } :{ self .port } : { e .args } "
951953 )
952954 except asyncio .CancelledError :
953955 raise # is Exception and not BaseException in 3.7 and earlier
954956 except Exception :
955- await self .disconnect ()
957+ await self .disconnect (nowait = True )
956958 raise
957959
958960 if self .health_check_interval :
@@ -976,17 +978,17 @@ async def read_response_without_lock(self, disable_decoding: bool = False):
976978 disable_decoding = disable_decoding
977979 )
978980 except asyncio .TimeoutError :
979- await self .disconnect ()
981+ await self .disconnect (nowait = True )
980982 raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
981983 except OSError as e :
982- await self .disconnect ()
984+ await self .disconnect (nowait = True )
983985 raise ConnectionError (
984986 f"Error while reading from { self .host } :{ self .port } : { e .args } "
985987 )
986988 except asyncio .CancelledError :
987989 raise # is Exception and not BaseException in 3.7 and earlier
988990 except Exception :
989- await self .disconnect ()
991+ await self .disconnect (nowait = True )
990992 raise
991993
992994 if self .health_check_interval :
0 commit comments