Skip to content

Commit 59ece83

Browse files
committed
Add tests for varions connection handshakes, async
1 parent baaec0d commit 59ece83

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

tests/test_asyncio/test_connect.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import logging
33
import socket
44
import ssl
5+
from unittest.mock import patch
56

67
import pytest
78
from redis.asyncio.connection import (
89
Connection,
10+
ResponseError,
911
SSLConnection,
1012
UnixDomainSocketConnection,
1113
)
@@ -61,6 +63,90 @@ async def test_tcp_ssl_connect(tcp_address):
6163
await conn.disconnect()
6264

6365

66+
@pytest.mark.parametrize(
67+
("use_server_ver", "use_protocol", "use_auth", "use_client_name"),
68+
[
69+
(5, 2, False, True),
70+
(5, 2, True, True),
71+
(5, 3, True, True),
72+
(6, 2, False, True),
73+
(6, 2, True, True),
74+
(6, 3, False, False),
75+
(6, 3, True, False),
76+
(6, 3, False, True),
77+
(6, 3, True, True),
78+
],
79+
)
80+
# @pytest.mark.parametrize("use_protocol", [2, 3])
81+
# @pytest.mark.parametrize("use_auth", [False, True])
82+
async def test_tcp_auth(
83+
tcp_address, use_protocol, use_auth, use_server_ver, use_client_name
84+
):
85+
"""
86+
Test that various initial handshake cases are handled correctly by the client
87+
"""
88+
got_auth = []
89+
got_protocol = None
90+
got_name = None
91+
92+
def on_auth(self, auth):
93+
got_auth[:] = auth
94+
95+
def on_protocol(self, proto):
96+
nonlocal got_protocol
97+
got_protocol = proto
98+
99+
def on_setname(self, name):
100+
nonlocal got_name
101+
got_name = name
102+
103+
def get_server_version(self):
104+
return use_server_ver
105+
106+
if use_auth:
107+
auth_args = {"username": "myuser", "password": "mypassword"}
108+
else:
109+
auth_args = {}
110+
got_protocol = None
111+
host, port = tcp_address
112+
conn = Connection(
113+
host=host,
114+
port=port,
115+
client_name=_CLIENT_NAME if use_client_name else None,
116+
socket_timeout=10,
117+
protocol=use_protocol,
118+
**auth_args,
119+
)
120+
try:
121+
with patch.multiple(
122+
resp.RespServer,
123+
on_auth=on_auth,
124+
get_server_version=get_server_version,
125+
on_protocol=on_protocol,
126+
on_setname=on_setname,
127+
):
128+
if use_server_ver < 6 and use_protocol > 2:
129+
with pytest.raises(ResponseError):
130+
await _assert_connect(conn, tcp_address)
131+
return
132+
133+
await _assert_connect(conn, tcp_address)
134+
if use_protocol == 3:
135+
assert got_protocol == use_protocol
136+
if use_auth:
137+
if use_server_ver < 6:
138+
assert got_auth == ["mypassword"]
139+
else:
140+
assert got_auth == ["myuser", "mypassword"]
141+
142+
if use_client_name:
143+
assert got_name == _CLIENT_NAME
144+
else:
145+
assert got_name is None
146+
finally:
147+
await conn.disconnect()
148+
149+
64150
async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
65151
stop_event = asyncio.Event()
66152
finished = asyncio.Event()

0 commit comments

Comments
 (0)