Skip to content

Commit 652ca79

Browse files
kristjanvalurakx
andauthored
Add nowait flag to asyncio.Connection.disconnect() (#2356)
* Don't wait for disconnect() when handling errors. This can result in other errors such as timeouts. * add CHANGES * Update redis/asyncio/connection.py Co-authored-by: Aarni Koskela <akx@iki.fi> * await a task to try to diagnose unittest failures in CI Co-authored-by: Aarni Koskela <akx@iki.fi>
1 parent 9fe8366 commit 652ca79

File tree

3 files changed

+55
-40
lines changed

3 files changed

+55
-40
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* add `nowait` flag to `asyncio.Connection.disconnect()`
12
* Update README.md links
23
* Fix timezone handling for datetime to unixtime conversions
34
* Fix start_id type for XAUTOCLAIM

redis/asyncio/connection.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ async def on_connect(self) -> None:
836836
if str_if_bytes(await self.read_response()) != "OK":
837837
raise ConnectionError("Invalid Database")
838838

839-
async def disconnect(self) -> None:
839+
async def disconnect(self, nowait: bool = False) -> None:
840840
"""Disconnects from the Redis server"""
841841
try:
842842
async with async_timeout.timeout(self.socket_connect_timeout):
@@ -846,8 +846,9 @@ async def disconnect(self) -> None:
846846
try:
847847
if os.getpid() == self.pid:
848848
self._writer.close() # type: ignore[union-attr]
849-
# py3.6 doesn't have this method
850-
if hasattr(self._writer, "wait_closed"):
849+
# wait for close to finish, except when handling errors and
850+
# forcefully disconnecting.
851+
if not nowait:
851852
await self._writer.wait_closed() # type: ignore[union-attr]
852853
except OSError:
853854
pass
@@ -902,10 +903,10 @@ async def send_packed_command(
902903
self._writer.writelines(command)
903904
await self._writer.drain()
904905
except asyncio.TimeoutError:
905-
await self.disconnect()
906+
await self.disconnect(nowait=True)
906907
raise TimeoutError("Timeout writing to socket") from None
907908
except OSError as e:
908-
await self.disconnect()
909+
await self.disconnect(nowait=True)
909910
if len(e.args) == 1:
910911
err_no, errmsg = "UNKNOWN", e.args[0]
911912
else:
@@ -915,7 +916,7 @@ async def send_packed_command(
915916
f"Error {err_no} while writing to socket. {errmsg}."
916917
) from e
917918
except Exception:
918-
await self.disconnect()
919+
await self.disconnect(nowait=True)
919920
raise
920921

921922
async def send_command(self, *args: Any, **kwargs: Any) -> None:
@@ -931,7 +932,7 @@ async def can_read(self, timeout: float = 0):
931932
try:
932933
return await self._parser.can_read(timeout)
933934
except OSError as e:
934-
await self.disconnect()
935+
await self.disconnect(nowait=True)
935936
raise ConnectionError(
936937
f"Error while reading from {self.host}:{self.port}: {e.args}"
937938
)
@@ -949,15 +950,15 @@ async def read_response(self, disable_decoding: bool = False):
949950
disable_decoding=disable_decoding
950951
)
951952
except asyncio.TimeoutError:
952-
await self.disconnect()
953+
await self.disconnect(nowait=True)
953954
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
954955
except OSError as e:
955-
await self.disconnect()
956+
await self.disconnect(nowait=True)
956957
raise ConnectionError(
957958
f"Error while reading from {self.host}:{self.port} : {e.args}"
958959
)
959960
except Exception:
960-
await self.disconnect()
961+
await self.disconnect(nowait=True)
961962
raise
962963

963964
if self.health_check_interval:

tests/test_asyncio/test_pubsub.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -819,14 +819,16 @@ async def mysetup(self, r, method):
819819
"type": "subscribe",
820820
}
821821

822-
async def mycleanup(self):
822+
async def myfinish(self):
823823
message = await self.messages.get()
824824
assert message == {
825825
"channel": b"foo",
826826
"data": 1,
827827
"pattern": None,
828828
"type": "subscribe",
829829
}
830+
831+
async def mykill(self):
830832
# kill thread
831833
async with self.cond:
832834
self.state = 4 # quit
@@ -836,41 +838,52 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method):
836838
"""
837839
Test that a socket error will cause reconnect
838840
"""
839-
async with async_timeout.timeout(self.timeout):
840-
await self.mysetup(r, method)
841-
# now, disconnect the connection, and wait for it to be re-established
842-
async with self.cond:
843-
assert self.state == 0
844-
self.state = 1
845-
with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
846-
mockobj.read_response.side_effect = socket.error
847-
mockobj.can_read.side_effect = socket.error
848-
# wait until task noticies the disconnect until we undo the patch
849-
await self.cond.wait_for(lambda: self.state >= 2)
850-
assert not self.pubsub.connection.is_connected
851-
# it is in a disconnecte state
852-
# wait for reconnect
853-
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
854-
assert self.state == 3
841+
try:
842+
async with async_timeout.timeout(self.timeout):
843+
await self.mysetup(r, method)
844+
# now, disconnect the connection, and wait for it to be re-established
845+
async with self.cond:
846+
assert self.state == 0
847+
self.state = 1
848+
with mock.patch.object(self.pubsub.connection, "_parser") as m:
849+
m.read_response.side_effect = socket.error
850+
m.can_read.side_effect = socket.error
851+
# wait until task noticies the disconnect until we
852+
# undo the patch
853+
await self.cond.wait_for(lambda: self.state >= 2)
854+
assert not self.pubsub.connection.is_connected
855+
# it is in a disconnecte state
856+
# wait for reconnect
857+
await self.cond.wait_for(
858+
lambda: self.pubsub.connection.is_connected
859+
)
860+
assert self.state == 3
855861

856-
await self.mycleanup()
862+
await self.myfinish()
863+
finally:
864+
await self.mykill()
857865

858866
async def test_reconnect_disconnect(self, r: redis.Redis, method):
859867
"""
860868
Test that a manual disconnect() will cause reconnect
861869
"""
862-
async with async_timeout.timeout(self.timeout):
863-
await self.mysetup(r, method)
864-
# now, disconnect the connection, and wait for it to be re-established
865-
async with self.cond:
866-
self.state = 1
867-
await self.pubsub.connection.disconnect()
868-
assert not self.pubsub.connection.is_connected
869-
# wait for reconnect
870-
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
871-
assert self.state == 3
872-
873-
await self.mycleanup()
870+
try:
871+
async with async_timeout.timeout(self.timeout):
872+
await self.mysetup(r, method)
873+
# now, disconnect the connection, and wait for it to be re-established
874+
async with self.cond:
875+
self.state = 1
876+
await self.pubsub.connection.disconnect()
877+
assert not self.pubsub.connection.is_connected
878+
# wait for reconnect
879+
await self.cond.wait_for(
880+
lambda: self.pubsub.connection.is_connected
881+
)
882+
assert self.state == 3
883+
884+
await self.myfinish()
885+
finally:
886+
await self.mykill()
874887

875888
async def loop(self):
876889
# reader loop, performing state transitions as it

0 commit comments

Comments
 (0)