|
1 | 1 | import asyncio |
| 2 | +import socket |
2 | 3 | import types |
| 4 | +from unittest.mock import patch |
3 | 5 |
|
4 | 6 | import pytest |
5 | 7 |
|
6 | | -from redis.asyncio.connection import PythonParser, UnixDomainSocketConnection |
7 | | -from redis.exceptions import InvalidResponse |
| 8 | +from redis.asyncio.connection import ( |
| 9 | + Connection, |
| 10 | + PythonParser, |
| 11 | + UnixDomainSocketConnection, |
| 12 | +) |
| 13 | +from redis.asyncio.retry import Retry |
| 14 | +from redis.backoff import NoBackoff |
| 15 | +from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError |
8 | 16 | from redis.utils import HIREDIS_AVAILABLE |
9 | 17 | from tests.conftest import skip_if_server_version_lt |
10 | 18 |
|
@@ -60,3 +68,44 @@ async def test_socket_param_regression(r): |
60 | 68 | async def test_can_run_concurrent_commands(r): |
61 | 69 | assert await r.ping() is True |
62 | 70 | assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) |
| 71 | + |
| 72 | + |
| 73 | +async def test_connect_retry_on_timeout_error(): |
| 74 | + """Test that the _connect function is retried in case of a timeout""" |
| 75 | + conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3)) |
| 76 | + origin_connect = conn._connect |
| 77 | + conn._connect = mock.AsyncMock() |
| 78 | + |
| 79 | + async def mock_connect(): |
| 80 | + # connect only on the last retry |
| 81 | + if conn._connect.call_count <= 2: |
| 82 | + raise socket.timeout |
| 83 | + else: |
| 84 | + return await origin_connect() |
| 85 | + |
| 86 | + conn._connect.side_effect = mock_connect |
| 87 | + await conn.connect() |
| 88 | + assert conn._connect.call_count == 3 |
| 89 | + |
| 90 | + |
| 91 | +async def test_connect_without_retry_on_os_error(): |
| 92 | + """Test that the _connect function is not being retried in case of a OSError""" |
| 93 | + with patch.object(Connection, "_connect") as _connect: |
| 94 | + _connect.side_effect = OSError("") |
| 95 | + conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2)) |
| 96 | + with pytest.raises(ConnectionError): |
| 97 | + await conn.connect() |
| 98 | + assert _connect.call_count == 1 |
| 99 | + |
| 100 | + |
| 101 | +async def test_connect_timeout_error_without_retry(): |
| 102 | + """Test that the _connect function is not being retried if retry_on_timeout is |
| 103 | + set to False""" |
| 104 | + conn = Connection(retry_on_timeout=False) |
| 105 | + conn._connect = mock.AsyncMock() |
| 106 | + conn._connect.side_effect = socket.timeout |
| 107 | + |
| 108 | + with pytest.raises(TimeoutError) as e: |
| 109 | + await conn.connect() |
| 110 | + assert conn._connect.call_count == 1 |
| 111 | + assert str(e.value) == "Timeout connecting to server" |
0 commit comments