@@ -839,7 +839,7 @@ async def on_connect(self) -> None:
839
839
if str_if_bytes (await self .read_response ()) != "OK" :
840
840
raise ConnectionError ("Invalid Database" )
841
841
842
- async def disconnect (self ) -> None :
842
+ async def disconnect (self , nowait : bool = False ) -> None :
843
843
"""Disconnects from the Redis server"""
844
844
try :
845
845
async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -849,8 +849,9 @@ async def disconnect(self) -> None:
849
849
try :
850
850
if os .getpid () == self .pid :
851
851
self ._writer .close () # type: ignore[union-attr]
852
- # py3.6 doesn't have this method
853
- if hasattr (self ._writer , "wait_closed" ):
852
+ # wait for close to finish, except when handling errors and
853
+ # forcecully disconnecting.
854
+ if not nowait :
854
855
await self ._writer .wait_closed () # type: ignore[union-attr]
855
856
except OSError :
856
857
pass
@@ -905,10 +906,10 @@ async def send_packed_command(
905
906
self ._writer .writelines (command )
906
907
await self ._writer .drain ()
907
908
except asyncio .TimeoutError :
908
- await self .disconnect ()
909
+ await self .disconnect (nowait = True )
909
910
raise TimeoutError ("Timeout writing to socket" ) from None
910
911
except OSError as e :
911
- await self .disconnect ()
912
+ await self .disconnect (nowait = True )
912
913
if len (e .args ) == 1 :
913
914
err_no , errmsg = "UNKNOWN" , e .args [0 ]
914
915
else :
@@ -918,7 +919,7 @@ async def send_packed_command(
918
919
f"Error { err_no } while writing to socket. { errmsg } ."
919
920
) from e
920
921
except BaseException :
921
- await self .disconnect ()
922
+ await self .disconnect (nowait = True )
922
923
raise
923
924
924
925
async def send_command (self , * args : Any , ** kwargs : Any ) -> None :
@@ -934,7 +935,7 @@ async def can_read(self, timeout: float = 0):
934
935
try :
935
936
return await self ._parser .can_read (timeout )
936
937
except OSError as e :
937
- await self .disconnect ()
938
+ await self .disconnect (nowait = True )
938
939
raise ConnectionError (
939
940
f"Error while reading from { self .host } :{ self .port } : { e .args } "
940
941
)
@@ -985,15 +986,15 @@ async def read_response_without_lock(self, disable_decoding: bool = False):
985
986
disable_decoding = disable_decoding
986
987
)
987
988
except asyncio .TimeoutError :
988
- await self .disconnect ()
989
+ await self .disconnect (nowait = True )
989
990
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
990
991
except OSError as e :
991
- await self .disconnect ()
992
+ await self .disconnect (nowait = True )
992
993
raise ConnectionError (
993
994
f"Error while reading from { self .host } :{ self .port } : { e .args } "
994
995
)
995
996
except BaseException :
996
- await self .disconnect ()
997
+ await self .disconnect (nowait = True )
997
998
raise
998
999
999
1000
if self .health_check_interval :
0 commit comments