Skip to content

Commit ecdffb1

Browse files
dvora-hchayim
authored andcommitted
Fix bug: client side caching causes unexpected disconnections (async version) (#3165)
* fix disconnects * skip test in cluster * add test * save return value from handle_push_response (without it 'read_response' return the push message) * insert return response from cache to the try block to prevent connection leak * enable to get connection with data avaliable to read in csc mode and change can_read_destructive to not read data * fix check if socket is empty (at_eof() can return False but this doesn't mean there's definitely more data to read) --------- Co-authored-by: Chayim <chayim@users.noreply.github.com>
1 parent b2c2547 commit ecdffb1

File tree

5 files changed

+77
-25
lines changed

5 files changed

+77
-25
lines changed

redis/_parsers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ async def can_read_destructive(self) -> bool:
182182
return True
183183
try:
184184
async with async_timeout(0):
185-
return await self._stream.read(1)
185+
return self._stream.at_eof()
186186
except TimeoutError:
187187
return False
188188

redis/_parsers/resp3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,9 @@ async def _read_response(
261261
)
262262
for _ in range(int(response))
263263
]
264-
await self.handle_push_response(response, disable_decoding, push_request)
264+
response = await self.handle_push_response(
265+
response, disable_decoding, push_request
266+
)
265267
else:
266268
raise InvalidResponse(f"Protocol Error: {raw!r}")
267269

redis/asyncio/client.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -625,25 +625,27 @@ async def execute_command(self, *args, **options):
625625
pool = self.connection_pool
626626
conn = self.connection or await pool.get_connection(command_name, **options)
627627
response_from_cache = await conn._get_from_local_cache(args)
628-
if response_from_cache is not None:
629-
return response_from_cache
630-
else:
631-
if self.single_connection_client:
632-
await self._single_conn_lock.acquire()
633-
try:
634-
response = await conn.retry.call_with_retry(
635-
lambda: self._send_command_parse_response(
636-
conn, command_name, *args, **options
637-
),
638-
lambda error: self._disconnect_raise(conn, error),
639-
)
640-
conn._add_to_local_cache(args, response, keys)
641-
return response
642-
finally:
643-
if self.single_connection_client:
644-
self._single_conn_lock.release()
645-
if not self.connection:
646-
await pool.release(conn)
628+
try:
629+
if response_from_cache is not None:
630+
return response_from_cache
631+
else:
632+
try:
633+
if self.single_connection_client:
634+
await self._single_conn_lock.acquire()
635+
response = await conn.retry.call_with_retry(
636+
lambda: self._send_command_parse_response(
637+
conn, command_name, *args, **options
638+
),
639+
lambda error: self._disconnect_raise(conn, error),
640+
)
641+
conn._add_to_local_cache(args, response, keys)
642+
return response
643+
finally:
644+
if self.single_connection_client:
645+
self._single_conn_lock.release()
646+
finally:
647+
if not self.connection:
648+
await pool.release(conn)
647649

648650
async def parse_response(
649651
self, connection: Connection, command_name: Union[str, bytes], **options

redis/asyncio/connection.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
684684

685685
def _socket_is_empty(self):
686686
"""Check if the socket is empty"""
687-
return not self._reader.at_eof()
687+
return len(self._reader._buffer) == 0
688688

689689
def _cache_invalidation_process(
690690
self, data: List[Union[str, Optional[List[str]]]]
@@ -1191,12 +1191,18 @@ def make_connection(self):
11911191
async def ensure_connection(self, connection: AbstractConnection):
11921192
"""Ensure that the connection object is connected and valid"""
11931193
await connection.connect()
1194-
# connections that the pool provides should be ready to send
1195-
# a command. if not, the connection was either returned to the
1194+
# if client caching is not enabled connections that the pool
1195+
# provides should be ready to send a command.
1196+
# if not, the connection was either returned to the
11961197
# pool before all data has been read or the socket has been
11971198
# closed. either way, reconnect and verify everything is good.
1199+
# (if caching enabled the connection will not always be ready
1200+
# to send a command because it may contain invalidation messages)
11981201
try:
1199-
if await connection.can_read_destructive():
1202+
if (
1203+
await connection.can_read_destructive()
1204+
and connection.client_cache is None
1205+
):
12001206
raise ConnectionError("Connection has data") from None
12011207
except (ConnectionError, OSError):
12021208
await connection.disconnect()

tests/test_asyncio/test_cache.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,48 @@ async def test_cache_return_copy(self, r):
142142
check = cache.get(("LRANGE", "mylist", 0, -1))
143143
assert check == [b"baz", b"bar", b"foo"]
144144

145+
@pytest.mark.onlynoncluster
146+
@pytest.mark.parametrize(
147+
"r",
148+
[{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
149+
indirect=True,
150+
)
151+
async def test_csc_not_cause_disconnects(self, r):
152+
r, cache = r
153+
id1 = await r.client_id()
154+
await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1})
155+
assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
156+
id2 = await r.client_id()
157+
158+
# client should get value from client cache
159+
assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
160+
assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [
161+
"1",
162+
"1",
163+
"1",
164+
"1",
165+
"1",
166+
]
167+
168+
await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2})
169+
id3 = await r.client_id()
170+
# client should get value from redis server post invalidate messages
171+
assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"]
172+
173+
await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3})
174+
# need to check that we get correct value 3 and not 2
175+
assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
176+
# client should get value from client cache
177+
assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
178+
179+
await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4})
180+
# need to check that we get correct value 4 and not 3
181+
assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
182+
# client should get value from client cache
183+
assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
184+
id4 = await r.client_id()
185+
assert id1 == id2 == id3 == id4
186+
145187

146188
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
147189
@pytest.mark.onlycluster

0 commit comments

Comments
 (0)