@@ -815,7 +815,7 @@ async def on_connect(self) -> None:
815
815
if str_if_bytes (await self .read_response ()) != "OK" :
816
816
raise ConnectionError ("Invalid Database" )
817
817
818
- async def disconnect (self ) -> None :
818
+ async def disconnect (self , nowait : bool = False ) -> None :
819
819
"""Disconnects from the Redis server"""
820
820
try :
821
821
async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -825,8 +825,9 @@ async def disconnect(self) -> None:
825
825
try :
826
826
if os .getpid () == self .pid :
827
827
self ._writer .close () # type: ignore[union-attr]
828
- # py3.6 doesn't have this method
829
- if hasattr (self ._writer , "wait_closed" ):
828
+ # wait for close to finish, except when handling errors and
829
+ # forcecully disconnecting.
830
+ if not nowait :
830
831
await self ._writer .wait_closed () # type: ignore[union-attr]
831
832
except OSError :
832
833
pass
@@ -927,10 +928,10 @@ async def read_response(self, disable_decoding: bool = False):
927
928
disable_decoding = disable_decoding
928
929
)
929
930
except asyncio .TimeoutError :
930
- await self .disconnect ()
931
+ await self .disconnect (nowait = True )
931
932
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
932
933
except OSError as e :
933
- await self .disconnect ()
934
+ await self .disconnect (nowait = True )
934
935
raise ConnectionError (
935
936
f"Error while reading from { self .host } :{ self .port } : { e .args } "
936
937
)
@@ -939,7 +940,7 @@ async def read_response(self, disable_decoding: bool = False):
939
940
# is subclass of Exception, not BaseException
940
941
raise
941
942
except Exception :
942
- await self .disconnect ()
943
+ await self .disconnect (nowait = True )
943
944
raise
944
945
945
946
if self .health_check_interval :
0 commit comments