Skip to content

Commit 2534aee

Browse files
committed
Add test for parse resumability
1 parent 4cf8dab commit 2534aee

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

tests/test_asyncio/test_connection.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pytest
77

8+
import redis
89
from redis.asyncio.connection import (
910
Connection,
1011
PythonParser,
@@ -112,3 +113,77 @@ async def test_connect_timeout_error_without_retry():
112113
await conn.connect()
113114
assert conn._connect.call_count == 1
114115
assert str(e.value) == "Timeout connecting to server"
116+
117+
118+
class TestError(BaseException):
119+
pass
120+
121+
122+
class InterruptingReader:
123+
"""
124+
A class simulating an asyncio input buffer, but raising a
125+
special exception every other read.
126+
"""
127+
128+
def __init__(self, data):
129+
self.data = data
130+
self.counter = 0
131+
self.pos = 0
132+
133+
def tick(self):
134+
self.counter += 1
135+
# return
136+
if (self.counter % 2) == 0:
137+
raise TestError()
138+
139+
async def read(self, want):
140+
self.tick()
141+
want = 5
142+
result = self.data[self.pos : self.pos + want]
143+
self.pos += len(result)
144+
return result
145+
146+
async def readline(self):
147+
self.tick()
148+
find = self.data.find(b"\n", self.pos)
149+
if find >= 0:
150+
result = self.data[self.pos : find + 1]
151+
else:
152+
result = self.data[self.pos :]
153+
self.pos += len(result)
154+
return result
155+
156+
async def readexactly(self, length):
157+
self.tick()
158+
result = self.data[self.pos : self.pos + length]
159+
if len(result) < length:
160+
raise asyncio.IncompleteReadError(result, None)
161+
self.pos += len(result)
162+
return result
163+
164+
165+
async def test_connection_parse_response_resume(r: redis.Redis):
166+
"""
167+
This test verifies that the Connection parser,
168+
be that PythonParser or HiredisParser,
169+
can be interrupted at IO time and then resume parsing.
170+
"""
171+
conn = Connection(**r.connection_pool.connection_kwargs)
172+
await conn.connect()
173+
message = (
174+
b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
175+
b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
176+
)
177+
178+
conn._parser._stream = InterruptingReader(message)
179+
for i in range(100):
180+
try:
181+
response = await conn.read_response()
182+
break
183+
except TestError:
184+
pass
185+
186+
else:
187+
pytest.fail("didn't receive a response")
188+
assert response
189+
assert i > 0

0 commit comments

Comments
 (0)