Skip to content

Commit 4cf8dab

Browse files
committed
PythonParser is now resumable if _stream IO is interrupted
1 parent 3a121be commit 4cf8dab

File tree

1 file changed

+50
-16
lines changed

1 file changed

+50
-16
lines changed

redis/asyncio/connection.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,14 @@ async def read_response(
208208
class PythonParser(BaseParser):
209209
"""Plain Python parsing class"""
210210

211-
__slots__ = BaseParser.__slots__ + ("encoder",)
211+
__slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
212212

213213
def __init__(self, socket_read_size: int):
214214
super().__init__(socket_read_size)
215215
self.encoder: Optional[Encoder] = None
216+
self._buffer = b""
217+
self._chunks = []
218+
self._pos = 0
216219

217220
def on_connect(self, connection: "Connection"):
218221
"""Called when the stream connects"""
@@ -229,6 +232,8 @@ def on_disconnect(self):
229232
self.encoder = None
230233

231234
async def can_read_destructive(self) -> bool:
235+
if self._buffer:
236+
return True
232237
if self._stream is None:
233238
raise RedisError("Buffer is closed.")
234239
try:
@@ -237,7 +242,19 @@ async def can_read_destructive(self) -> bool:
237242
except asyncio.TimeoutError:
238243
return False
239244

240-
async def read_response(
245+
async def read_response(self, disable_decoding: bool = False):
246+
if self._chunks:
247+
# augment parsing buffer with previously read data
248+
self._buffer += b"".join(self._chunks)
249+
self._chunks.clear()
250+
self._pos = 0
251+
response = await self._read_response(disable_decoding=disable_decoding)
252+
# Successfully parsing a response allows us to clear our parsing buffer
253+
self._buffer = b""
254+
self._chunks.clear()
255+
return response
256+
257+
async def _read_response(
241258
self, disable_decoding: bool = False
242259
) -> Union[EncodableT, ResponseError, None]:
243260
if not self._stream or not self.encoder:
@@ -282,7 +299,7 @@ async def read_response(
282299
if length == -1:
283300
return None
284301
response = [
285-
(await self.read_response(disable_decoding)) for _ in range(length)
302+
(await self._read_response(disable_decoding)) for _ in range(length)
286303
]
287304
if isinstance(response, bytes) and disable_decoding is False:
288305
response = self.encoder.decode(response)
@@ -293,25 +310,42 @@ async def _read(self, length: int) -> bytes:
293310
Read `length` bytes of data. These are assumed to be followed
294311
by a '\r\n' terminator which is subsequently discarded.
295312
"""
296-
if self._stream is None:
297-
raise RedisError("Buffer is closed.")
298-
try:
299-
data = await self._stream.readexactly(length + 2)
300-
except asyncio.IncompleteReadError as error:
301-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
302-
return data[:-2]
313+
want = length + 2
314+
end = self._pos + want
315+
if len(self._buffer) >= end:
316+
result = self._buffer[self._pos : end - 2]
317+
else:
318+
if self._stream is None:
319+
raise RedisError("Buffer is closed.")
320+
tail = self._buffer[self._pos :]
321+
try:
322+
data = await self._stream.readexactly(want - len(tail))
323+
except asyncio.IncompleteReadError as error:
324+
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
325+
result = (tail + data)[:-2]
326+
self._chunks.append(data)
327+
self._pos += want
328+
return result
303329

304330
async def _readline(self) -> bytes:
305331
"""
306332
read an unknown number of bytes up to the next '\r\n'
307333
line separator, which is discarded.
308334
"""
309-
if self._stream is None:
310-
raise RedisError("Buffer is closed.")
311-
data = await self._stream.readline()
312-
if not data.endswith(b"\r\n"):
313-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
314-
return data[:-2]
335+
found = self._buffer.find(b"\r\n", self._pos)
336+
if found >= 0:
337+
result = self._buffer[self._pos : found]
338+
else:
339+
if self._stream is None:
340+
raise RedisError("Buffer is closed.")
341+
tail = self._buffer[self._pos :]
342+
data = await self._stream.readline()
343+
if not data.endswith(b"\r\n"):
344+
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
345+
result = (tail + data)[:-2]
346+
self._chunks.append(data)
347+
self._pos += len(result) + 2
348+
return result
315349

316350

317351
class HiredisParser(BaseParser):

0 commit comments

Comments
 (0)