Skip to content

Commit 8cbf7f5

Browse files
authored
Support client side caching with ConnectionPool (redis#3099)
* sync * async * fixs connection mocks * fix async connection mock * fix test_asyncio/test_connection.py::test_single_connection * add test for cache blacklist and flushdb at the end of each test * fix review comments
1 parent 6d77c6d commit 8cbf7f5

File tree

10 files changed

+318
-246
lines changed

10 files changed

+318
-246
lines changed

redis/cache.py renamed to redis/_cache.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,11 @@ class _LocalCache:
178178
"""
179179

180180
def __init__(
181-
self, max_size: int, ttl: int, eviction_policy: EvictionPolicy, **kwargs
181+
self,
182+
max_size: int = 100,
183+
ttl: int = 0,
184+
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
185+
**kwargs,
182186
):
183187
self.max_size = max_size
184188
self.ttl = ttl

redis/asyncio/client.py

Lines changed: 18 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
cast,
2626
)
2727

28+
from redis._cache import (
29+
DEFAULT_BLACKLIST,
30+
DEFAULT_EVICTION_POLICY,
31+
DEFAULT_WHITELIST,
32+
_LocalCache,
33+
)
2834
from redis._parsers.helpers import (
2935
_RedisCallbacks,
3036
_RedisCallbacksRESP2,
@@ -39,12 +45,6 @@
3945
)
4046
from redis.asyncio.lock import Lock
4147
from redis.asyncio.retry import Retry
42-
from redis.cache import (
43-
DEFAULT_BLACKLIST,
44-
DEFAULT_EVICTION_POLICY,
45-
DEFAULT_WHITELIST,
46-
_LocalCache,
47-
)
4848
from redis.client import (
4949
EMPTY_RESPONSE,
5050
NEVER_DECODE,
@@ -67,7 +67,7 @@
6767
TimeoutError,
6868
WatchError,
6969
)
70-
from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT
70+
from redis.typing import ChannelT, EncodableT, KeyT
7171
from redis.utils import (
7272
HIREDIS_AVAILABLE,
7373
_set_info_logger,
@@ -294,6 +294,13 @@ def __init__(
294294
"lib_version": lib_version,
295295
"redis_connect_func": redis_connect_func,
296296
"protocol": protocol,
297+
"cache_enable": cache_enable,
298+
"client_cache": client_cache,
299+
"cache_max_size": cache_max_size,
300+
"cache_ttl": cache_ttl,
301+
"cache_eviction_policy": cache_eviction_policy,
302+
"cache_blacklist": cache_blacklist,
303+
"cache_whitelist": cache_whitelist,
297304
}
298305
# based on input, setup appropriate connection args
299306
if unix_socket_path is not None:
@@ -350,16 +357,6 @@ def __init__(
350357
# on a set of redis commands
351358
self._single_conn_lock = asyncio.Lock()
352359

353-
self.client_cache = client_cache
354-
if cache_enable:
355-
self.client_cache = _LocalCache(
356-
cache_max_size, cache_ttl, cache_eviction_policy
357-
)
358-
if self.client_cache is not None:
359-
self.cache_blacklist = cache_blacklist
360-
self.cache_whitelist = cache_whitelist
361-
self.client_cache_initialized = False
362-
363360
def __repr__(self):
364361
return (
365362
f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -374,10 +371,6 @@ async def initialize(self: _RedisT) -> _RedisT:
374371
async with self._single_conn_lock:
375372
if self.connection is None:
376373
self.connection = await self.connection_pool.get_connection("_")
377-
if self.client_cache is not None:
378-
self.connection._parser.set_invalidation_push_handler(
379-
self._cache_invalidation_process
380-
)
381374
return self
382375

383376
def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -596,8 +589,6 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
596589
close_connection_pool is None and self.auto_close_connection_pool
597590
):
598591
await self.connection_pool.disconnect()
599-
if self.client_cache:
600-
self.client_cache.flush()
601592

602593
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
603594
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
@@ -626,89 +617,28 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
626617
):
627618
raise error
628619

629-
def _cache_invalidation_process(
630-
self, data: List[Union[str, Optional[List[str]]]]
631-
) -> None:
632-
"""
633-
Invalidate (delete) all redis commands associated with a specific key.
634-
`data` is a list of strings, where the first string is the invalidation message
635-
and the second string is the list of keys to invalidate.
636-
(if the list of keys is None, then all keys are invalidated)
637-
"""
638-
if data[1] is not None:
639-
for key in data[1]:
640-
self.client_cache.invalidate(str_if_bytes(key))
641-
else:
642-
self.client_cache.flush()
643-
644-
async def _get_from_local_cache(self, command: str):
645-
"""
646-
If the command is in the local cache, return the response
647-
"""
648-
if (
649-
self.client_cache is None
650-
or command[0] in self.cache_blacklist
651-
or command[0] not in self.cache_whitelist
652-
):
653-
return None
654-
while not self.connection._is_socket_empty():
655-
await self.connection.read_response(push_request=True)
656-
return self.client_cache.get(command)
657-
658-
def _add_to_local_cache(
659-
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
660-
):
661-
"""
662-
Add the command and response to the local cache if the command
663-
is allowed to be cached
664-
"""
665-
if (
666-
self.client_cache is not None
667-
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
668-
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
669-
):
670-
self.client_cache.set(command, response, keys)
671-
672-
def delete_from_local_cache(self, command: str):
673-
"""
674-
Delete the command from the local cache
675-
"""
676-
try:
677-
self.client_cache.delete(command)
678-
except AttributeError:
679-
pass
680-
681620
# COMMAND EXECUTION AND PROTOCOL PARSING
682621
async def execute_command(self, *args, **options):
683622
"""Execute a command and return a parsed response"""
684623
await self.initialize()
685624
command_name = args[0]
686625
keys = options.pop("keys", None) # keys are used only for client side caching
687-
response_from_cache = await self._get_from_local_cache(args)
626+
pool = self.connection_pool
627+
conn = self.connection or await pool.get_connection(command_name, **options)
628+
response_from_cache = await conn._get_from_local_cache(args)
688629
if response_from_cache is not None:
689630
return response_from_cache
690631
else:
691-
pool = self.connection_pool
692-
conn = self.connection or await pool.get_connection(command_name, **options)
693-
694632
if self.single_connection_client:
695633
await self._single_conn_lock.acquire()
696634
try:
697-
if self.client_cache is not None and not self.client_cache_initialized:
698-
await conn.retry.call_with_retry(
699-
lambda: self._send_command_parse_response(
700-
conn, "CLIENT", *("CLIENT", "TRACKING", "ON")
701-
),
702-
lambda error: self._disconnect_raise(conn, error),
703-
)
704-
self.client_cache_initialized = True
705635
response = await conn.retry.call_with_retry(
706636
lambda: self._send_command_parse_response(
707637
conn, command_name, *args, **options
708638
),
709639
lambda error: self._disconnect_raise(conn, error),
710640
)
711-
self._add_to_local_cache(args, response, keys)
641+
conn._add_to_local_cache(args, response, keys)
712642
return response
713643
finally:
714644
if self.single_connection_client:

redis/asyncio/connection.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,15 @@
4747
ResponseError,
4848
TimeoutError,
4949
)
50-
from redis.typing import EncodableT
50+
from redis.typing import EncodableT, KeysT, ResponseT
5151
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
5252

53+
from .._cache import (
54+
DEFAULT_BLACKLIST,
55+
DEFAULT_EVICTION_POLICY,
56+
DEFAULT_WHITELIST,
57+
_LocalCache,
58+
)
5359
from .._parsers import (
5460
BaseParser,
5561
Encoder,
@@ -114,6 +120,9 @@ class AbstractConnection:
114120
"encoder",
115121
"ssl_context",
116122
"protocol",
123+
"client_cache",
124+
"cache_blacklist",
125+
"cache_whitelist",
117126
"_reader",
118127
"_writer",
119128
"_parser",
@@ -148,6 +157,13 @@ def __init__(
148157
encoder_class: Type[Encoder] = Encoder,
149158
credential_provider: Optional[CredentialProvider] = None,
150159
protocol: Optional[int] = 2,
160+
cache_enable: bool = False,
161+
client_cache: Optional[_LocalCache] = None,
162+
cache_max_size: int = 100,
163+
cache_ttl: int = 0,
164+
cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
165+
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
166+
cache_whitelist: List[str] = DEFAULT_WHITELIST,
151167
):
152168
if (username or password) and credential_provider is not None:
153169
raise DataError(
@@ -205,6 +221,14 @@ def __init__(
205221
if p < 2 or p > 3:
206222
raise ConnectionError("protocol must be either 2 or 3")
207223
self.protocol = protocol
224+
if cache_enable:
225+
_cache = _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy)
226+
else:
227+
_cache = None
228+
self.client_cache = client_cache if client_cache is not None else _cache
229+
if self.client_cache is not None:
230+
self.cache_blacklist = cache_blacklist
231+
self.cache_whitelist = cache_whitelist
208232

209233
def __del__(self, _warnings: Any = warnings):
210234
# For some reason, the individual streams don't get properly garbage
@@ -395,6 +419,11 @@ async def on_connect(self) -> None:
395419
# if a database is specified, switch to it. Also pipeline this
396420
if self.db:
397421
await self.send_command("SELECT", self.db)
422+
# if client caching is enabled, start tracking
423+
if self.client_cache:
424+
await self.send_command("CLIENT", "TRACKING", "ON")
425+
await self.read_response()
426+
self._parser.set_invalidation_push_handler(self._cache_invalidation_process)
398427

399428
# read responses from pipeline
400429
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -429,6 +458,9 @@ async def disconnect(self, nowait: bool = False) -> None:
429458
raise TimeoutError(
430459
f"Timed out closing connection after {self.socket_connect_timeout}"
431460
) from None
461+
finally:
462+
if self.client_cache:
463+
self.client_cache.flush()
432464

433465
async def _send_ping(self):
434466
"""Send PING, expect PONG in return"""
@@ -646,10 +678,62 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
646678
output.append(SYM_EMPTY.join(pieces))
647679
return output
648680

649-
def _is_socket_empty(self):
681+
def _socket_is_empty(self):
650682
"""Check if the socket is empty"""
651683
return not self._reader.at_eof()
652684

685+
def _cache_invalidation_process(
686+
self, data: List[Union[str, Optional[List[str]]]]
687+
) -> None:
688+
"""
689+
Invalidate (delete) all redis commands associated with a specific key.
690+
`data` is a list of strings, where the first string is the invalidation message
691+
and the second string is the list of keys to invalidate.
692+
(if the list of keys is None, then all keys are invalidated)
693+
"""
694+
if data[1] is not None:
695+
self.client_cache.flush()
696+
else:
697+
for key in data[1]:
698+
self.client_cache.invalidate(str_if_bytes(key))
699+
700+
async def _get_from_local_cache(self, command: str):
701+
"""
702+
If the command is in the local cache, return the response
703+
"""
704+
if (
705+
self.client_cache is None
706+
or command[0] in self.cache_blacklist
707+
or command[0] not in self.cache_whitelist
708+
):
709+
return None
710+
while not self._socket_is_empty():
711+
await self.read_response(push_request=True)
712+
return self.client_cache.get(command)
713+
714+
def _add_to_local_cache(
715+
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
716+
):
717+
"""
718+
Add the command and response to the local cache if the command
719+
is allowed to be cached
720+
"""
721+
if (
722+
self.client_cache is not None
723+
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
724+
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
725+
):
726+
self.client_cache.set(command, response, keys)
727+
728+
def delete_from_local_cache(self, command: str):
729+
"""
730+
Delete the command from the local cache
731+
"""
732+
try:
733+
self.client_cache.delete(command)
734+
except AttributeError:
735+
pass
736+
653737

654738
class Connection(AbstractConnection):
655739
"Manages TCP communication to and from a Redis server"

0 commit comments

Comments
 (0)