From 25e85e51e57b7aae9eb8fc77cfb0a45a07a501a7 Mon Sep 17 00:00:00 2001 From: Mehdi ABAAKOUK Date: Thu, 16 Mar 2023 12:51:19 +0100 Subject: [PATCH] fix: replace async_timeout by asyncio.timeout (#2602) async_timeout does not support python 3.11 https://github.com/aio-libs/async-timeout/pull/295 And have two years old annoying bugs: https://github.com/aio-libs/async-timeout/issues/229 https://github.com/redis/redis-py/issues/2551 Since asyncio.timeout has been shipped in python 3.11, we should start using it. Partially fixes 2551 --- CHANGES | 1 + redis/asyncio/connection.py | 21 +++++++++++++-------- setup.py | 2 +- tests/test_asyncio/test_pubsub.py | 22 +++++++++++++--------- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/CHANGES b/CHANGES index e83660d6ac..3e4eba44be 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Use asyncio.timeout() instead of async_timeout.timeout() for python >= 3.11 (#2602) * Add test and fix async HiredisParser when reading during a disconnect() (#2349) * Use hiredis-py pack_command if available. * Support `.unlink()` in ClusterPipeline diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 056998e9e0..93db37e46d 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -5,6 +5,7 @@ import os import socket import ssl +import sys import threading import weakref from itertools import chain @@ -24,7 +25,11 @@ ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse -import async_timeout +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + from redis.asyncio.retry import Retry from redis.backoff import NoBackoff @@ -242,7 +247,7 @@ async def can_read_destructive(self) -> bool: if self._stream is None: raise RedisError("Buffer is closed.") try: - async with async_timeout.timeout(0): + async with async_timeout(0): return await self._stream.read(1) except asyncio.TimeoutError: return False @@ -380,7 +385,7 @@ async def can_read_destructive(self): if self._reader.gets(): return True try: - async with async_timeout.timeout(0): + async with async_timeout(0): return await self.read_from_socket() except asyncio.TimeoutError: return False @@ -635,7 +640,7 @@ async def connect(self): async def _connect(self): """Create a TCP socket connection""" - async with async_timeout.timeout(self.socket_connect_timeout): + async with async_timeout(self.socket_connect_timeout): reader, writer = await asyncio.open_connection( host=self.host, port=self.port, @@ -722,7 +727,7 @@ async def on_connect(self) -> None: async def disconnect(self, nowait: bool = False) -> None: """Disconnects from the Redis server""" try: - async with async_timeout.timeout(self.socket_connect_timeout): + async with async_timeout(self.socket_connect_timeout): self._parser.on_disconnect() if not self.is_connected: return @@ -827,7 +832,7 @@ async def read_response( read_timeout = timeout if timeout is not None else self.socket_timeout try: if read_timeout is not None: - async with async_timeout.timeout(read_timeout): + async with async_timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) @@ -1118,7 +1123,7 @@ def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: return pieces async def _connect(self): - async with async_timeout.timeout(self.socket_connect_timeout): + async with async_timeout(self.socket_connect_timeout): reader, writer = await asyncio.open_unix_connection(path=self.path) self._reader = reader self._writer = writer @@ -1589,7 +1594,7 @@ async def get_connection(self, command_name, *keys, **options): # self.timeout then raise a ``ConnectionError``. connection = None try: - async with async_timeout.timeout(self.timeout): + async with async_timeout(self.timeout): connection = await self.pool.get() except (asyncio.QueueEmpty, asyncio.TimeoutError): # Note that this is not caught by the redis client and will be diff --git a/setup.py b/setup.py index 060e9da7b0..ceeeb81699 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ install_requires=[ 'importlib-metadata >= 1.0; python_version < "3.8"', 'typing-extensions; python_version<"3.8"', - "async-timeout>=4.0.2", + 'async-timeout>=4.0.2; python_version<"3.11"', ], classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index c2a9130e83..0df7847e66 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -5,7 +5,11 @@ from typing import Optional from unittest.mock import patch -import async_timeout +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + import pytest import pytest_asyncio @@ -21,7 +25,7 @@ def with_timeout(t): def wrapper(corofunc): @functools.wraps(corofunc) async def run(*args, **kwargs): - async with async_timeout.timeout(t): + async with async_timeout(t): return await corofunc(*args, **kwargs) return run @@ -648,7 +652,7 @@ async def test_reconnect_listen(self, r: redis.Redis, pubsub): async def loop(): # must make sure the task exits - async with async_timeout.timeout(2): + async with async_timeout(2): nonlocal interrupt await pubsub.subscribe("foo") while True: @@ -677,7 +681,7 @@ async def loop_step(): task = asyncio.get_running_loop().create_task(loop()) # get the initial connect message - async with async_timeout.timeout(1): + async with async_timeout(1): message = await messages.get() assert message == { "channel": b"foo", @@ -776,7 +780,7 @@ def callback(message): if n == 1: break await asyncio.sleep(0.1) - async with async_timeout.timeout(0.1): + async with async_timeout(0.1): message = await messages.get() task.cancel() # we expect a cancelled error, not the Runtime error @@ -839,7 +843,7 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method): Test that a socket error will cause reconnect """ try: - async with async_timeout.timeout(self.timeout): + async with async_timeout(self.timeout): await self.mysetup(r, method) # now, disconnect the connection, and wait for it to be re-established async with self.cond: @@ -868,7 +872,7 @@ async def test_reconnect_disconnect(self, r: redis.Redis, method): Test that a manual disconnect() will cause reconnect """ try: - async with async_timeout.timeout(self.timeout): + async with async_timeout(self.timeout): await self.mysetup(r, method) # now, disconnect the connection, and wait for it to be re-established async with self.cond: @@ -923,7 +927,7 @@ async def loop_step_get_message(self): async def loop_step_listen(self): # get a single message via listen() try: - async with async_timeout.timeout(0.1): + async with async_timeout(0.1): async for message in self.pubsub.listen(): await self.messages.put(message) return True @@ -947,7 +951,7 @@ async def test_outer_timeout(self, r: redis.Redis): assert pubsub.connection.is_connected async def get_msg_or_timeout(timeout=0.1): - async with async_timeout.timeout(timeout): + async with async_timeout(timeout): # blocking method to return messages while True: response = await pubsub.parse_response(block=True)