@@ -826,7 +826,7 @@ async def on_connect(self) -> None:
826
826
if str_if_bytes (await self .read_response ()) != "OK" :
827
827
raise ConnectionError ("Invalid Database" )
828
828
829
- async def disconnect (self ) -> None :
829
+ async def disconnect (self , nowait : bool = False ) -> None :
830
830
"""Disconnects from the Redis server"""
831
831
try :
832
832
async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -836,8 +836,10 @@ async def disconnect(self) -> None:
836
836
try :
837
837
if os .getpid () == self .pid :
838
838
self ._writer .close () # type: ignore[union-attr]
839
+ # wait for close to finish, except when handling errors and
840
+ # forcecully disconnecting.
839
841
# py3.6 doesn't have this method
840
- if hasattr (self ._writer , "wait_closed" ):
842
+ if not nowait and hasattr (self ._writer , "wait_closed" ):
841
843
await self ._writer .wait_closed () # type: ignore[union-attr]
842
844
except OSError :
843
845
pass
@@ -892,10 +894,10 @@ async def send_packed_command(
892
894
self ._writer .writelines (command )
893
895
await self ._writer .drain ()
894
896
except asyncio .TimeoutError :
895
- await self .disconnect ()
897
+ await self .disconnect (nowait = True )
896
898
raise TimeoutError ("Timeout writing to socket" ) from None
897
899
except OSError as e :
898
- await self .disconnect ()
900
+ await self .disconnect (nowait = True )
899
901
if len (e .args ) == 1 :
900
902
err_no , errmsg = "UNKNOWN" , e .args [0 ]
901
903
else :
@@ -907,7 +909,7 @@ async def send_packed_command(
907
909
except asyncio .CancelledError :
908
910
raise # is Exception and not BaseException in 3.7 and earlier
909
911
except Exception :
910
- await self .disconnect ()
912
+ await self .disconnect (nowait = True )
911
913
raise
912
914
913
915
async def send_command (self , * args : Any , ** kwargs : Any ) -> None :
@@ -923,7 +925,7 @@ async def can_read(self, timeout: float = 0):
923
925
try :
924
926
return await self ._parser .can_read (timeout )
925
927
except OSError as e :
926
- await self .disconnect ()
928
+ await self .disconnect (nowait = True )
927
929
raise ConnectionError (
928
930
f"Error while reading from { self .host } :{ self .port } : { e .args } "
929
931
)
@@ -942,17 +944,17 @@ async def read_response(self, disable_decoding: bool = False):
942
944
disable_decoding = disable_decoding
943
945
)
944
946
except asyncio .TimeoutError :
945
- await self .disconnect ()
947
+ await self .disconnect (nowait = True )
946
948
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
947
949
except OSError as e :
948
- await self .disconnect ()
950
+ await self .disconnect (nowait = True )
949
951
raise ConnectionError (
950
952
f"Error while reading from { self .host } :{ self .port } : { e .args } "
951
953
)
952
954
except asyncio .CancelledError :
953
955
raise # is Exception and not BaseException in 3.7 and earlier
954
956
except Exception :
955
- await self .disconnect ()
957
+ await self .disconnect (nowait = True )
956
958
raise
957
959
958
960
if self .health_check_interval :
@@ -976,17 +978,17 @@ async def read_response_without_lock(self, disable_decoding: bool = False):
976
978
disable_decoding = disable_decoding
977
979
)
978
980
except asyncio .TimeoutError :
979
- await self .disconnect ()
981
+ await self .disconnect (nowait = True )
980
982
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
981
983
except OSError as e :
982
- await self .disconnect ()
984
+ await self .disconnect (nowait = True )
983
985
raise ConnectionError (
984
986
f"Error while reading from { self .host } :{ self .port } : { e .args } "
985
987
)
986
988
except asyncio .CancelledError :
987
989
raise # is Exception and not BaseException in 3.7 and earlier
988
990
except Exception :
989
- await self .disconnect ()
991
+ await self .disconnect (nowait = True )
990
992
raise
991
993
992
994
if self .health_check_interval :
0 commit comments