diff --git a/dev_requirements.txt b/dev_requirements.txt index ef3b1aa22d..48ec278d83 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,6 @@ click==8.0.4 black==24.3.0 +cachetools flake8==5.0.4 flake8-isort==6.0.0 flynt~=0.69.0 diff --git a/dockers/sentinel.conf b/dockers/sentinel.conf index 1a33f53344..75f711e5d4 100644 --- a/dockers/sentinel.conf +++ b/dockers/sentinel.conf @@ -1,4 +1,5 @@ -sentinel monitor redis-py-test 127.0.0.1 6379 2 +sentinel resolve-hostnames yes +sentinel monitor redis-py-test redis 6379 2 sentinel down-after-milliseconds redis-py-test 5000 sentinel failover-timeout redis-py-test 60000 sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/redis/_cache.py b/redis/_cache.py index 7acfdde3e7..90288383d6 100644 --- a/redis/_cache.py +++ b/redis/_cache.py @@ -4,14 +4,20 @@ from abc import ABC, abstractmethod from collections import OrderedDict, defaultdict from enum import Enum -from typing import List +from typing import List, Sequence, Union from redis.typing import KeyT, ResponseT -DEFAULT_EVICTION_POLICY = "lru" +class EvictionPolicy(Enum): + LRU = "lru" + LFU = "lfu" + RANDOM = "random" -DEFAULT_BLACKLIST = [ + +DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU + +DEFAULT_DENY_LIST = [ "BF.CARD", "BF.DEBUG", "BF.EXISTS", @@ -71,8 +77,7 @@ "TTL", ] - -DEFAULT_WHITELIST = [ +DEFAULT_ALLOW_LIST = [ "BITCOUNT", "BITFIELD_RO", "BITPOS", @@ -155,12 +160,6 @@ _ACCESS_COUNT = "access_count" -class EvictionPolicy(Enum): - LRU = "lru" - LFU = "lfu" - RANDOM = "random" - - class AbstractCache(ABC): """ An abstract base class for client caching implementations. @@ -168,19 +167,24 @@ class AbstractCache(ABC): """ @abstractmethod - def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): + def set( + self, + command: Union[str, Sequence[str]], + response: ResponseT, + keys_in_command: List[KeyT], + ): pass @abstractmethod - def get(self, command: str) -> ResponseT: + def get(self, command: Union[str, Sequence[str]]) -> ResponseT: pass @abstractmethod - def delete_command(self, command: str): + def delete_command(self, command: Union[str, Sequence[str]]): pass @abstractmethod - def delete_many(self, commands): + def delete_commands(self, commands: List[Union[str, Sequence[str]]]): pass @abstractmethod @@ -215,7 +219,6 @@ def __init__( max_size: int = 10000, ttl: int = 0, eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, - **kwargs, ): self.max_size = max_size self.ttl = ttl @@ -224,12 +227,17 @@ def __init__( self.key_commands_map = defaultdict(set) self.commands_ttl_list = [] - def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): + def set( + self, + command: Union[str, Sequence[str]], + response: ResponseT, + keys_in_command: List[KeyT], + ): """ Set a redis command and its response in the cache. Args: - command (str): The redis command. + command (Union[str, Sequence[str]]): The redis command. response (ResponseT): The response associated with the command. keys_in_command (List[KeyT]): The list of keys used in the command. """ @@ -244,12 +252,12 @@ def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]): self._update_key_commands_map(keys_in_command, command) self.commands_ttl_list.append(command) - def get(self, command: str) -> ResponseT: + def get(self, command: Union[str, Sequence[str]]) -> ResponseT: """ Get the response for a redis command from the cache. Args: - command (str): The redis command. + command (Union[str, Sequence[str]]): The redis command. Returns: ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa @@ -261,12 +269,12 @@ def get(self, command: str) -> ResponseT: self._update_access(command) return copy.deepcopy(self.cache[command]["response"]) - def delete_command(self, command: str): + def delete_command(self, command: Union[str, Sequence[str]]): """ Delete a redis command and its metadata from the cache. Args: - command (str): The redis command to be deleted. + command (Union[str, Sequence[str]]): The redis command to be deleted. """ if command in self.cache: keys_in_command = self.cache[command].get("keys") @@ -274,8 +282,16 @@ def delete_command(self, command: str): self.commands_ttl_list.remove(command) del self.cache[command] - def delete_many(self, commands): - pass + def delete_commands(self, commands: List[Union[str, Sequence[str]]]): + """ + Delete multiple commands and their metadata from the cache. + + Args: + commands (List[Union[str, Sequence[str]]]): The list of commands to be + deleted. + """ + for command in commands: + self.delete_command(command) def flush(self): """Clear the entire cache, removing all redis commands and metadata.""" @@ -283,12 +299,12 @@ def flush(self): self.key_commands_map.clear() self.commands_ttl_list = [] - def _is_expired(self, command: str) -> bool: + def _is_expired(self, command: Union[str, Sequence[str]]) -> bool: """ Check if a redis command has expired based on its time-to-live. Args: - command (str): The redis command. + command (Union[str, Sequence[str]]): The redis command. Returns: bool: True if the command has expired, False otherwise. @@ -297,56 +313,60 @@ def _is_expired(self, command: str) -> bool: return False return time.monotonic() - self.cache[command]["ctime"] > self.ttl - def _update_access(self, command: str): + def _update_access(self, command: Union[str, Sequence[str]]): """ Update the access information for a redis command based on the eviction policy. Args: - command (str): The redis command. + command (Union[str, Sequence[str]]): The redis command. """ - if self.eviction_policy == EvictionPolicy.LRU.value: + if self.eviction_policy == EvictionPolicy.LRU: self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.LFU.value: + elif self.eviction_policy == EvictionPolicy.LFU: self.cache[command]["access_count"] = ( self.cache.get(command, {}).get("access_count", 0) + 1 ) self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.RANDOM.value: + elif self.eviction_policy == EvictionPolicy.RANDOM: pass # Random eviction doesn't require updates def _evict(self): """Evict a redis command from the cache based on the eviction policy.""" if self._is_expired(self.commands_ttl_list[0]): self.delete_command(self.commands_ttl_list[0]) - elif self.eviction_policy == EvictionPolicy.LRU.value: + elif self.eviction_policy == EvictionPolicy.LRU: self.cache.popitem(last=False) - elif self.eviction_policy == EvictionPolicy.LFU.value: + elif self.eviction_policy == EvictionPolicy.LFU: min_access_command = min( self.cache, key=lambda k: self.cache[k].get("access_count", 0) ) self.cache.pop(min_access_command) - elif self.eviction_policy == EvictionPolicy.RANDOM.value: + elif self.eviction_policy == EvictionPolicy.RANDOM: random_command = random.choice(list(self.cache.keys())) self.cache.pop(random_command) - def _update_key_commands_map(self, keys: List[KeyT], command: str): + def _update_key_commands_map( + self, keys: List[KeyT], command: Union[str, Sequence[str]] + ): """ Update the key_commands_map with command that uses the keys. Args: keys (List[KeyT]): The list of keys used in the command. - command (str): The redis command. + command (Union[str, Sequence[str]]): The redis command. """ for key in keys: self.key_commands_map[key].add(command) - def _del_key_commands_map(self, keys: List[KeyT], command: str): + def _del_key_commands_map( + self, keys: List[KeyT], command: Union[str, Sequence[str]] + ): """ Remove a redis command from the key_commands_map. Args: keys (List[KeyT]): The list of keys used in the redis command. - command (str): The redis command. + command (Union[str, Sequence[str]]): The redis command. """ for key in keys: self.key_commands_map[key].remove(command) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index c4c6c51a1e..1c71561ff7 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -25,9 +25,9 @@ ) from redis._cache import ( - DEFAULT_BLACKLIST, + DEFAULT_ALLOW_LIST, + DEFAULT_DENY_LIST, DEFAULT_EVICTION_POLICY, - DEFAULT_WHITELIST, AbstractCache, ) from redis._parsers.helpers import ( @@ -242,8 +242,8 @@ def __init__( cache_max_size: int = 100, cache_ttl: int = 0, cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_blacklist: List[str] = DEFAULT_BLACKLIST, - cache_whitelist: List[str] = DEFAULT_WHITELIST, + cache_deny_list: List[str] = DEFAULT_DENY_LIST, + cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): """ Initialize a new Redis client. @@ -298,8 +298,8 @@ def __init__( "cache_max_size": cache_max_size, "cache_ttl": cache_ttl, "cache_policy": cache_policy, - "cache_blacklist": cache_blacklist, - "cache_whitelist": cache_whitelist, + "cache_deny_list": cache_deny_list, + "cache_allow_list": cache_allow_list, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -636,7 +636,8 @@ async def execute_command(self, *args, **options): ), lambda error: self._disconnect_raise(conn, error), ) - conn._add_to_local_cache(args, response, keys) + if keys: + conn._add_to_local_cache(args, response, keys) return response finally: if self.single_connection_client: @@ -671,31 +672,22 @@ async def parse_response( return response def flush_cache(self): - try: - if self.connection: - self.connection.client_cache.flush() - else: - self.connection_pool.flush_cache() - except AttributeError: - pass + if self.connection: + self.connection.flush_cache() + else: + self.connection_pool.flush_cache() def delete_command_from_cache(self, command): - try: - if self.connection: - self.connection.client_cache.delete_command(command) - else: - self.connection_pool.delete_command_from_cache(command) - except AttributeError: - pass + if self.connection: + self.connection.delete_command_from_cache(command) + else: + self.connection_pool.delete_command_from_cache(command) def invalidate_key_from_cache(self, key): - try: - if self.connection: - self.connection.client_cache.invalidate_key(key) - else: - self.connection_pool.invalidate_key_from_cache(key) - except AttributeError: - pass + if self.connection: + self.connection.invalidate_key_from_cache(key) + else: + self.connection_pool.invalidate_key_from_cache(key) StrictRedis = Redis diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index cffc268c80..ebe5567918 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -20,9 +20,9 @@ ) from redis._cache import ( - DEFAULT_BLACKLIST, + DEFAULT_ALLOW_LIST, + DEFAULT_DENY_LIST, DEFAULT_EVICTION_POLICY, - DEFAULT_WHITELIST, AbstractCache, ) from redis._parsers import AsyncCommandsParser, Encoder @@ -280,8 +280,8 @@ def __init__( cache_max_size: int = 100, cache_ttl: int = 0, cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_blacklist: List[str] = DEFAULT_BLACKLIST, - cache_whitelist: List[str] = DEFAULT_WHITELIST, + cache_deny_list: List[str] = DEFAULT_DENY_LIST, + cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ) -> None: if db: raise RedisClusterException( @@ -331,8 +331,8 @@ def __init__( "cache_max_size": cache_max_size, "cache_ttl": cache_ttl, "cache_policy": cache_policy, - "cache_blacklist": cache_blacklist, - "cache_whitelist": cache_whitelist, + "cache_deny_list": cache_deny_list, + "cache_allow_list": cache_allow_list, } if ssl: @@ -936,6 +936,18 @@ def lock( thread_local=thread_local, ) + def flush_cache(self): + if self.nodes_manager: + self.nodes_manager.flush_cache() + + def delete_command_from_cache(self, command): + if self.nodes_manager: + self.nodes_manager.delete_command_from_cache(command) + + def invalidate_key_from_cache(self, key): + if self.nodes_manager: + self.nodes_manager.invalidate_key_from_cache(key) + class ClusterNode: """ @@ -1075,7 +1087,8 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Read response try: response = await self.parse_response(connection, args[0], **kwargs) - connection._add_to_local_cache(args, response, keys) + if keys: + connection._add_to_local_cache(args, response, keys) return response finally: # Release connection @@ -1106,6 +1119,18 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: return ret + def flush_cache(self): + for connection in self._connections: + connection.flush_cache() + + def delete_command_from_cache(self, command): + for connection in self._connections: + connection.delete_command_from_cache(command) + + def invalidate_key_from_cache(self, key): + for connection in self._connections: + connection.invalidate_key_from_cache(key) + class NodesManager: __slots__ = ( @@ -1391,6 +1416,18 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port + def flush_cache(self): + for node in self.nodes_cache.values(): + node.flush_cache() + + def delete_command_from_cache(self, command): + for node in self.nodes_cache.values(): + node.delete_command_from_cache(command) + + def invalidate_key_from_cache(self, key): + for node in self.nodes_cache.values(): + node.invalidate_key_from_cache(key) + class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index e03e8c63da..8e186eab54 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -50,9 +50,9 @@ from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes from .._cache import ( - DEFAULT_BLACKLIST, + DEFAULT_ALLOW_LIST, + DEFAULT_DENY_LIST, DEFAULT_EVICTION_POLICY, - DEFAULT_WHITELIST, AbstractCache, _LocalCache, ) @@ -119,8 +119,8 @@ class AbstractConnection: "ssl_context", "protocol", "client_cache", - "cache_blacklist", - "cache_whitelist", + "cache_deny_list", + "cache_allow_list", "_reader", "_writer", "_parser", @@ -160,8 +160,8 @@ def __init__( cache_max_size: int = 10000, cache_ttl: int = 0, cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_blacklist: List[str] = DEFAULT_BLACKLIST, - cache_whitelist: List[str] = DEFAULT_WHITELIST, + cache_deny_list: List[str] = DEFAULT_DENY_LIST, + cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): if (username or password) and credential_provider is not None: raise DataError( @@ -229,8 +229,8 @@ def __init__( raise RedisError( "client caching is only supported with protocol version 3 or higher" ) - self.cache_blacklist = cache_blacklist - self.cache_whitelist = cache_whitelist + self.cache_deny_list = cache_deny_list + self.cache_allow_list = cache_allow_list def __del__(self, _warnings: Any = warnings): # For some reason, the individual streams don't get properly garbage @@ -695,7 +695,7 @@ def _cache_invalidation_process( and the second string is the list of keys to invalidate. (if the list of keys is None, then all keys are invalidated) """ - if data[1] is not None: + if data[1] is None: self.client_cache.flush() else: for key in data[1]: @@ -707,8 +707,8 @@ async def _get_from_local_cache(self, command: str): """ if ( self.client_cache is None - or command[0] in self.cache_blacklist - or command[0] not in self.cache_whitelist + or command[0] in self.cache_deny_list + or command[0] not in self.cache_allow_list ): return None while not self._socket_is_empty(): @@ -724,11 +724,23 @@ def _add_to_local_cache( """ if ( self.client_cache is not None - and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) - and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) + and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) + and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) ): self.client_cache.set(command, response, keys) + def flush_cache(self): + if self.client_cache: + self.client_cache.flush() + + def delete_command_from_cache(self, command): + if self.client_cache: + self.client_cache.delete_command(command) + + def invalidate_key_from_cache(self, key): + if self.client_cache: + self.client_cache.invalidate_key(key) + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -1251,33 +1263,18 @@ def set_retry(self, retry: "Retry") -> None: def flush_cache(self): connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - try: - connection.client_cache.flush() - except AttributeError: - # cache is not enabled - pass + connection.flush_cache() def delete_command_from_cache(self, command: str): connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - try: - connection.client_cache.delete_command(command) - except AttributeError: - # cache is not enabled - pass + connection.delete_command_from_cache(command) def invalidate_key_from_cache(self, key: str): connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - try: - connection.client_cache.invalidate_key(key) - except AttributeError: - # cache is not enabled - pass + connection.invalidate_key_from_cache(key) class BlockingConnectionPool(ConnectionPool): diff --git a/redis/client.py b/redis/client.py index c45f5da790..7a76e3b269 100755 --- a/redis/client.py +++ b/redis/client.py @@ -7,9 +7,9 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union from redis._cache import ( - DEFAULT_BLACKLIST, + DEFAULT_ALLOW_LIST, + DEFAULT_DENY_LIST, DEFAULT_EVICTION_POLICY, - DEFAULT_WHITELIST, AbstractCache, ) from redis._parsers.encoders import Encoder @@ -220,8 +220,8 @@ def __init__( cache_max_size: int = 10000, cache_ttl: int = 0, cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_blacklist: List[str] = DEFAULT_BLACKLIST, - cache_whitelist: List[str] = DEFAULT_WHITELIST, + cache_deny_list: List[str] = DEFAULT_DENY_LIST, + cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ) -> None: """ Initialize a new Redis client. @@ -278,8 +278,8 @@ def __init__( "cache_max_size": cache_max_size, "cache_ttl": cache_ttl, "cache_policy": cache_policy, - "cache_blacklist": cache_blacklist, - "cache_whitelist": cache_whitelist, + "cache_deny_list": cache_deny_list, + "cache_allow_list": cache_allow_list, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -570,7 +570,8 @@ def execute_command(self, *args, **options): ), lambda error: self._disconnect_raise(conn, error), ) - conn._add_to_local_cache(args, response, keys) + if keys: + conn._add_to_local_cache(args, response, keys) return response finally: if not self.connection: @@ -597,31 +598,22 @@ def parse_response(self, connection, command_name, **options): return response def flush_cache(self): - try: - if self.connection: - self.connection.client_cache.flush() - else: - self.connection_pool.flush_cache() - except AttributeError: - pass + if self.connection: + self.connection.flush_cache() + else: + self.connection_pool.flush_cache() def delete_command_from_cache(self, command): - try: - if self.connection: - self.connection.client_cache.delete_command(command) - else: - self.connection_pool.delete_command_from_cache(command) - except AttributeError: - pass + if self.connection: + self.connection.delete_command_from_cache(command) + else: + self.connection_pool.delete_command_from_cache(command) def invalidate_key_from_cache(self, key): - try: - if self.connection: - self.connection.client_cache.invalidate_key(key) - else: - self.connection_pool.invalidate_key_from_cache(key) - except AttributeError: - pass + if self.connection: + self.connection.invalidate_key_from_cache(key) + else: + self.connection_pool.invalidate_key_from_cache(key) StrictRedis = Redis diff --git a/redis/cluster.py b/redis/cluster.py index cfe902115e..e792d51867 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -172,8 +172,8 @@ def parse_cluster_myshardid(resp, **options): "cache_max_size", "cache_ttl", "cache_policy", - "cache_blacklist", - "cache_whitelist", + "cache_deny_list", + "cache_allow_list", ) KWARGS_DISABLED_KEYS = ("host", "port") @@ -1164,7 +1164,8 @@ def _execute_command(self, target_node, *args, **kwargs): response = self.cluster_response_callbacks[command]( response, **kwargs ) - connection._add_to_local_cache(args, response, keys) + if keys: + connection._add_to_local_cache(args, response, keys) return response except AuthenticationError: raise @@ -1265,6 +1266,18 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) + def flush_cache(self): + if self.nodes_manager: + self.nodes_manager.flush_cache() + + def delete_command_from_cache(self, command): + if self.nodes_manager: + self.nodes_manager.delete_command_from_cache(command) + + def invalidate_key_from_cache(self, key): + if self.nodes_manager: + self.nodes_manager.invalidate_key_from_cache(key) + class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): @@ -1293,6 +1306,18 @@ def __del__(self): if self.redis_connection is not None: self.redis_connection.close() + def flush_cache(self): + if self.redis_connection is not None: + self.redis_connection.flush_cache() + + def delete_command_from_cache(self, command): + if self.redis_connection is not None: + self.redis_connection.delete_command_from_cache(command) + + def invalidate_key_from_cache(self, key): + if self.redis_connection is not None: + self.redis_connection.invalidate_key_from_cache(key) + class LoadBalancer: """ @@ -1659,6 +1684,18 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port + def flush_cache(self): + for node in self.nodes_cache.values(): + node.flush_cache() + + def delete_command_from_cache(self, command): + for node in self.nodes_cache.values(): + node.delete_command_from_cache(command) + + def invalidate_key_from_cache(self, key): + for node in self.nodes_cache.values(): + node.invalidate_key_from_cache(key) + class ClusterPubSub(PubSub): """ diff --git a/redis/connection.py b/redis/connection.py index 4912f89f63..2cefd40d34 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,13 +9,13 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Sequence, Type, Union from urllib.parse import parse_qs, unquote, urlparse from ._cache import ( - DEFAULT_BLACKLIST, + DEFAULT_ALLOW_LIST, + DEFAULT_DENY_LIST, DEFAULT_EVICTION_POLICY, - DEFAULT_WHITELIST, AbstractCache, _LocalCache, ) @@ -162,8 +162,8 @@ def __init__( cache_max_size: int = 10000, cache_ttl: int = 0, cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_blacklist: List[str] = DEFAULT_BLACKLIST, - cache_whitelist: List[str] = DEFAULT_WHITELIST, + cache_deny_list: List[str] = DEFAULT_DENY_LIST, + cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): """ Initialize a new Connection. @@ -239,8 +239,8 @@ def __init__( raise RedisError( "client caching is only supported with protocol version 3 or higher" ) - self.cache_blacklist = cache_blacklist - self.cache_whitelist = cache_whitelist + self.cache_deny_list = cache_deny_list + self.cache_allow_list = cache_allow_list def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -623,14 +623,14 @@ def _cache_invalidation_process( for key in data[1]: self.client_cache.invalidate_key(str_if_bytes(key)) - def _get_from_local_cache(self, command: str): + def _get_from_local_cache(self, command: Sequence[str]): """ If the command is in the local cache, return the response """ if ( self.client_cache is None - or command[0] in self.cache_blacklist - or command[0] not in self.cache_whitelist + or command[0] in self.cache_deny_list + or command[0] not in self.cache_allow_list ): return None while self.can_read(): @@ -638,7 +638,7 @@ def _get_from_local_cache(self, command: str): return self.client_cache.get(command) def _add_to_local_cache( - self, command: Tuple[str], response: ResponseT, keys: List[KeysT] + self, command: Sequence[str], response: ResponseT, keys: List[KeysT] ): """ Add the command and response to the local cache if the command @@ -646,11 +646,23 @@ def _add_to_local_cache( """ if ( self.client_cache is not None - and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist) - and (self.cache_whitelist == [] or command[0] in self.cache_whitelist) + and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) + and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) ): self.client_cache.set(command, response, keys) + def flush_cache(self): + if self.client_cache: + self.client_cache.flush() + + def delete_command_from_cache(self, command: Union[str, Sequence[str]]): + if self.client_cache: + self.client_cache.delete_command(command) + + def invalidate_key_from_cache(self, key: KeysT): + if self.client_cache: + self.client_cache.invalidate_key(key) + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -1279,37 +1291,22 @@ def flush_cache(self): self._checkpid() with self._lock: connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - try: - connection.client_cache.flush() - except AttributeError: - # cache is not enabled - pass + connection.flush_cache() def delete_command_from_cache(self, command: str): self._checkpid() with self._lock: connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - try: - connection.client_cache.delete_command(command) - except AttributeError: - # cache is not enabled - pass + connection.delete_command_from_cache(command) def invalidate_key_from_cache(self, key: str): self._checkpid() with self._lock: connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - try: - connection.client_cache.invalidate_key(key) - except AttributeError: - # cache is not enabled - pass + connection.invalidate_key_from_cache(key) class BlockingConnectionPool(ConnectionPool): diff --git a/tests/conftest.py b/tests/conftest.py index 8786e2b9f0..e783b6e8f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import pytest import redis from packaging.version import Version +from redis import Sentinel from redis.backoff import NoBackoff from redis.connection import Connection, parse_url from redis.exceptions import RedisClusterException @@ -105,6 +106,19 @@ def pytest_addoption(parser): "--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop" ) + parser.addoption( + "--sentinels", + action="store", + default="localhost:26379,localhost:26380,localhost:26381", + help="Comma-separated list of sentinel IPs and ports", + ) + parser.addoption( + "--master-service", + action="store", + default="redis-py-test", + help="Name of the Redis master service that the sentinels are monitoring", + ) + def _get_info(redis_url): client = redis.Redis.from_url(redis_url) @@ -352,6 +366,34 @@ def sslclient(request): yield client +@pytest.fixture() +def sentinel_setup(local_cache, request): + sentinel_ips = request.config.getoption("--sentinels") + sentinel_endpoints = [ + (ip.strip(), int(port.strip())) + for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) + ] + kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} + sentinel = Sentinel( + sentinel_endpoints, + socket_timeout=0.1, + client_cache=local_cache, + protocol=3, + **kwargs, + ) + yield sentinel + for s in sentinel.sentinels: + s.close() + + +@pytest.fixture() +def master(request, sentinel_setup): + master_service = request.config.getoption("--master-service") + master = sentinel_setup.master_for(master_service) + yield master + master.close() + + def _gen_cluster_mock_resp(r, response): connection = Mock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index c6afec5af6..cff239fa11 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -7,6 +7,7 @@ import redis.asyncio as redis from packaging.version import Version from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser +from redis.asyncio import Sentinel from redis.asyncio.client import Monitor from redis.asyncio.connection import Connection, parse_url from redis.asyncio.retry import Retry @@ -136,6 +137,34 @@ async def decoded_r(create_redis): return await create_redis(decode_responses=True) +@pytest_asyncio.fixture() +async def sentinel_setup(local_cache, request): + sentinel_ips = request.config.getoption("--sentinels") + sentinel_endpoints = [ + (ip.strip(), int(port.strip())) + for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) + ] + kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} + sentinel = Sentinel( + sentinel_endpoints, + socket_timeout=0.1, + client_cache=local_cache, + protocol=3, + **kwargs, + ) + yield sentinel + for s in sentinel.sentinels: + await s.aclose() + + +@pytest_asyncio.fixture() +async def master(request, sentinel_setup): + master_service = request.config.getoption("--master-service") + master = sentinel_setup.master_for(master_service) + yield master + await master.aclose() + + def _gen_cluster_mock_resp(r, response): connection = mock.AsyncMock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py index 4762bb7c05..7a7f881ce2 100644 --- a/tests/test_asyncio/test_cache.py +++ b/tests/test_asyncio/test_cache.py @@ -2,7 +2,7 @@ import pytest import pytest_asyncio -from redis._cache import _LocalCache +from redis._cache import EvictionPolicy, _LocalCache from redis.utils import HIREDIS_AVAILABLE @@ -14,10 +14,15 @@ async def r(request, create_redis): yield r, cache +@pytest_asyncio.fixture() +async def local_cache(): + yield _LocalCache() + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") class TestLocalCache: - @pytest.mark.onlynoncluster @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) + @pytest.mark.onlynoncluster async def test_get_from_cache(self, r, r2): r, cache = r # add key to redis @@ -36,7 +41,7 @@ async def test_get_from_cache(self, r, r2): assert await r.get("foo") == b"barbar" @pytest.mark.parametrize("r", [{"cache": _LocalCache(max_size=3)}], indirect=True) - async def test_cache_max_size(self, r): + async def test_cache_lru_eviction(self, r): r, cache = r # add 3 keys to redis await r.set("foo", "bar") @@ -71,7 +76,9 @@ async def test_cache_ttl(self, r): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize( - "r", [{"cache": _LocalCache(max_size=3, eviction_policy="lfu")}], indirect=True + "r", + [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}], + indirect=True, ) async def test_cache_lfu_eviction(self, r): r, cache = r @@ -95,12 +102,12 @@ async def test_cache_lfu_eviction(self, r): assert cache.get(("GET", "foo")) == b"bar" assert cache.get(("GET", "foo2")) is None - @pytest.mark.onlynoncluster @pytest.mark.parametrize( "r", [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], indirect=True, ) + @pytest.mark.onlynoncluster async def test_cache_decode_response(self, r): r, cache = r await r.set("foo", "bar") @@ -119,10 +126,10 @@ async def test_cache_decode_response(self, r): @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"cache_blacklist": ["LLEN"]}}], + [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}], indirect=True, ) - async def test_cache_blacklist(self, r): + async def test_cache_deny_list(self, r): r, cache = r # add list to redis await r.lpush("mylist", "foo", "bar", "baz") @@ -131,6 +138,20 @@ async def test_cache_blacklist(self, r): assert cache.get(("LLEN", "mylist")) is None assert cache.get(("LINDEX", "mylist", 1)) == b"bar" + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}], + indirect=True, + ) + async def test_cache_allow_list(self, r): + r, cache = r + # add list to redis + await r.lpush("mylist", "foo", "bar", "baz") + assert await r.llen("mylist") == 3 + assert await r.lindex("mylist", 1) == b"bar" + assert cache.get(("LLEN", "mylist")) == 3 + assert cache.get(("LINDEX", "mylist", 1)) is None + @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) async def test_cache_return_copy(self, r): r, cache = r @@ -142,12 +163,12 @@ async def test_cache_return_copy(self, r): check = cache.get(("LRANGE", "mylist", 0, -1)) assert check == [b"baz", b"bar", b"foo"] - @pytest.mark.onlynoncluster @pytest.mark.parametrize( "r", [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], indirect=True, ) + @pytest.mark.onlynoncluster async def test_csc_not_cause_disconnects(self, r): r, cache = r id1 = await r.client_id() @@ -184,6 +205,99 @@ async def test_csc_not_cause_disconnects(self, r): id4 = await r.client_id() assert id1 == id2 == id3 == id4 + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_execute_command_keys_provided(self, r): + r, cache = r + assert await r.execute_command("SET", "b", "2") is True + assert await r.execute_command("GET", "b", keys=["b"]) == "2" + assert cache.get(("GET", "b")) == "2" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_execute_command_keys_not_provided(self, r): + r, cache = r + assert await r.execute_command("SET", "b", "2") is True + assert ( + await r.execute_command("GET", "b") == "2" + ) # keys not provided, not cached + assert cache.get(("GET", "b")) is None + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_delete_one_command(self, r): + r, cache = r + assert await r.mset({"a{a}": 1, "b{a}": 1}) is True + assert await r.set("c", 1) is True + assert await r.mget("a{a}", "b{a}") == ["1", "1"] + assert await r.get("c") == "1" + # values should be in local cache + assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] + assert cache.get(("GET", "c")) == "1" + # delete one command from the cache + r.delete_command_from_cache(("MGET", "a{a}", "b{a}")) + # the other command is still in the local cache anymore + assert cache.get(("MGET", "a{a}", "b{a}")) is None + assert cache.get(("GET", "c")) == "1" + # get from redis + assert await r.mget("a{a}", "b{a}") == ["1", "1"] + assert await r.get("c") == "1" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_invalidate_key(self, r): + r, cache = r + assert await r.mset({"a{a}": 1, "b{a}": 1}) is True + assert await r.set("c", 1) is True + assert await r.mget("a{a}", "b{a}") == ["1", "1"] + assert await r.get("c") == "1" + # values should be in local cache + assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] + assert cache.get(("GET", "c")) == "1" + # invalidate one key from the cache + r.invalidate_key_from_cache("b{a}") + # one other command is still in the local cache anymore + assert cache.get(("MGET", "a{a}", "b{a}")) is None + assert cache.get(("GET", "c")) == "1" + # get from redis + assert await r.mget("a{a}", "b{a}") == ["1", "1"] + assert await r.get("c") == "1" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_flush_entire_cache(self, r): + r, cache = r + assert await r.mset({"a{a}": 1, "b{a}": 1}) is True + assert await r.set("c", 1) is True + assert await r.mget("a{a}", "b{a}") == ["1", "1"] + assert await r.get("c") == "1" + # values should be in local cache + assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] + assert cache.get(("GET", "c")) == "1" + # flush the local cache + r.flush_cache() + # the commands are not in the local cache anymore + assert cache.get(("MGET", "a{a}", "b{a}")) is None + assert cache.get(("GET", "c")) is None + # get from redis + assert await r.mget("a{a}", "b{a}") == ["1", "1"] + assert await r.get("c") == "1" + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster @@ -228,3 +342,67 @@ async def test_cache_decode_response(self, r): assert cache.get(("GET", "foo")) is None # get key from redis assert await r.get("foo") == "barbar" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_execute_command_keys_provided(self, r): + r, cache = r + assert await r.execute_command("SET", "b", "2") is True + assert await r.execute_command("GET", "b", keys=["b"]) == "2" + assert cache.get(("GET", "b")) == "2" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_execute_command_keys_not_provided(self, r): + r, cache = r + assert await r.execute_command("SET", "b", "2") is True + assert ( + await r.execute_command("GET", "b") == "2" + ) # keys not provided, not cached + assert cache.get(("GET", "b")) is None + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +class TestSentinelLocalCache: + + async def test_get_from_cache(self, local_cache, master): + await master.set("foo", "bar") + # get key from redis and save in local cache + assert await master.get("foo") == b"bar" + # get key from local cache + assert local_cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + await master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await master.ping() + # the command is not in the local cache anymore + assert local_cache.get(("GET", "foo")) is None + # get key from redis + assert await master.get("foo") == b"barbar" + + @pytest.mark.parametrize( + "sentinel_setup", + [{"kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_cache_decode_response(self, local_cache, sentinel_setup, master): + await master.set("foo", "bar") + # get key from redis and save in local cache + assert await master.get("foo") == "bar" + # get key from local cache + assert local_cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + await master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + await master.ping() + # the command is not in the local cache anymore + assert local_cache.get(("GET", "foo")) is None + # get key from redis + assert await master.get("foo") == "barbar" diff --git a/tests/test_cache.py b/tests/test_cache.py index dd33afd23e..022364e87a 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,8 +1,13 @@ import time +from collections import defaultdict +from typing import List, Sequence, Union +import cachetools import pytest import redis -from redis._cache import _LocalCache +from redis import RedisError +from redis._cache import AbstractCache, EvictionPolicy, _LocalCache +from redis.typing import KeyT, ResponseT from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client @@ -11,17 +16,28 @@ def r(request): cache = request.param.get("cache") kwargs = request.param.get("kwargs", {}) + protocol = request.param.get("protocol", 3) + single_connection_client = request.param.get("single_connection_client", False) with _get_client( - redis.Redis, request, protocol=3, client_cache=cache, **kwargs + redis.Redis, + request, + single_connection_client=single_connection_client, + protocol=protocol, + client_cache=cache, + **kwargs, ) as client: yield client, cache - # client.flushdb() + + +@pytest.fixture() +def local_cache(): + return _LocalCache() @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") class TestLocalCache: - @pytest.mark.onlynoncluster @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) + @pytest.mark.onlynoncluster def test_get_from_cache(self, r, r2): r, cache = r # add key to redis @@ -39,8 +55,12 @@ def test_get_from_cache(self, r, r2): # get key from redis assert r.get("foo") == b"barbar" - @pytest.mark.parametrize("r", [{"cache": _LocalCache(max_size=3)}], indirect=True) - def test_cache_max_size(self, r): + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(max_size=3)}], + indirect=True, + ) + def test_cache_lru_eviction(self, r): r, cache = r # add 3 keys to redis r.set("foo", "bar") @@ -75,7 +95,9 @@ def test_cache_ttl(self, r): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize( - "r", [{"cache": _LocalCache(max_size=3, eviction_policy="lfu")}], indirect=True + "r", + [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}], + indirect=True, ) def test_cache_lfu_eviction(self, r): r, cache = r @@ -99,12 +121,12 @@ def test_cache_lfu_eviction(self, r): assert cache.get(("GET", "foo")) == b"bar" assert cache.get(("GET", "foo2")) is None - @pytest.mark.onlynoncluster @pytest.mark.parametrize( "r", [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], indirect=True, ) + @pytest.mark.onlynoncluster def test_cache_decode_response(self, r): r, cache = r r.set("foo", "bar") @@ -123,10 +145,10 @@ def test_cache_decode_response(self, r): @pytest.mark.parametrize( "r", - [{"cache": _LocalCache(), "kwargs": {"cache_blacklist": ["LLEN"]}}], + [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}], indirect=True, ) - def test_cache_blacklist(self, r): + def test_cache_deny_list(self, r): r, cache = r # add list to redis r.lpush("mylist", "foo", "bar", "baz") @@ -135,6 +157,19 @@ def test_cache_blacklist(self, r): assert cache.get(("LLEN", "mylist")) is None assert cache.get(("LINDEX", "mylist", 1)) == b"bar" + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}], + indirect=True, + ) + def test_cache_allow_list(self, r): + r, cache = r + r.lpush("mylist", "foo", "bar", "baz") + assert r.llen("mylist") == 3 + assert r.lindex("mylist", 1) == b"bar" + assert cache.get(("LLEN", "mylist")) == 3 + assert cache.get(("LINDEX", "mylist", 1)) is None + @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) def test_cache_return_copy(self, r): r, cache = r @@ -146,12 +181,12 @@ def test_cache_return_copy(self, r): check = cache.get(("LRANGE", "mylist", 0, -1)) assert check == [b"baz", b"bar", b"foo"] - @pytest.mark.onlynoncluster @pytest.mark.parametrize( "r", [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], indirect=True, ) + @pytest.mark.onlynoncluster def test_csc_not_cause_disconnects(self, r): r, cache = r id1 = r.client_id() @@ -189,6 +224,198 @@ def test_csc_not_cause_disconnects(self, r): id4 = r.client_id() assert id1 == id2 == id3 == id4 + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_multiple_commands_same_key(self, r): + r, cache = r + r.mset({"a": 1, "b": 1}) + assert r.mget("a", "b") == ["1", "1"] + # value should be in local cache + assert cache.get(("MGET", "a", "b")) == ["1", "1"] + # set only one key + r.set("a", 2) + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert cache.get(("MGET", "a", "b")) is None + # get from redis + assert r.mget("a", "b") == ["2", "1"] + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_delete_one_command(self, r): + r, cache = r + r.mset({"a{a}": 1, "b{a}": 1}) + r.set("c", 1) + assert r.mget("a{a}", "b{a}") == ["1", "1"] + assert r.get("c") == "1" + # values should be in local cache + assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] + assert cache.get(("GET", "c")) == "1" + # delete one command from the cache + r.delete_command_from_cache(("MGET", "a{a}", "b{a}")) + # the other command is still in the local cache anymore + assert cache.get(("MGET", "a{a}", "b{a}")) is None + assert cache.get(("GET", "c")) == "1" + # get from redis + assert r.mget("a{a}", "b{a}") == ["1", "1"] + assert r.get("c") == "1" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_delete_several_commands(self, r): + r, cache = r + r.mset({"a{a}": 1, "b{a}": 1}) + r.set("c", 1) + assert r.mget("a{a}", "b{a}") == ["1", "1"] + assert r.get("c") == "1" + # values should be in local cache + assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] + assert cache.get(("GET", "c")) == "1" + # delete the commands from the cache + cache.delete_commands([("MGET", "a{a}", "b{a}"), ("GET", "c")]) + # the commands are not in the local cache anymore + assert cache.get(("MGET", "a{a}", "b{a}")) is None + assert cache.get(("GET", "c")) is None + # get from redis + assert r.mget("a{a}", "b{a}") == ["1", "1"] + assert r.get("c") == "1" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_invalidate_key(self, r): + r, cache = r + r.mset({"a{a}": 1, "b{a}": 1}) + r.set("c", 1) + assert r.mget("a{a}", "b{a}") == ["1", "1"] + assert r.get("c") == "1" + # values should be in local cache + assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] + assert cache.get(("GET", "c")) == "1" + # invalidate one key from the cache + r.invalidate_key_from_cache("b{a}") + # one other command is still in the local cache anymore + assert cache.get(("MGET", "a{a}", "b{a}")) is None + assert cache.get(("GET", "c")) == "1" + # get from redis + assert r.mget("a{a}", "b{a}") == ["1", "1"] + assert r.get("c") == "1" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_flush_entire_cache(self, r): + r, cache = r + r.mset({"a{a}": 1, "b{a}": 1}) + r.set("c", 1) + assert r.mget("a{a}", "b{a}") == ["1", "1"] + assert r.get("c") == "1" + # values should be in local cache + assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] + assert cache.get(("GET", "c")) == "1" + # flush the local cache + r.flush_cache() + # the commands are not in the local cache anymore + assert cache.get(("MGET", "a{a}", "b{a}")) is None + assert cache.get(("GET", "c")) is None + # get from redis + assert r.mget("a{a}", "b{a}") == ["1", "1"] + assert r.get("c") == "1" + + @pytest.mark.onlynoncluster + def test_cache_not_available_with_resp2(self, request): + with pytest.raises(RedisError) as e: + _get_client(redis.Redis, request, protocol=2, client_cache=_LocalCache()) + assert "protocol version 3 or higher" in str(e.value) + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_execute_command_args_not_split(self, r): + r, cache = r + assert r.execute_command("SET a 1") == "OK" + assert r.execute_command("GET a") == "1" + # "get a" is not whitelisted by default, the args should be separated + assert cache.get(("GET a",)) is None + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_execute_command_keys_provided(self, r): + r, cache = r + assert r.execute_command("SET", "b", "2") is True + assert r.execute_command("GET", "b", keys=["b"]) == "2" + assert cache.get(("GET", "b")) == "2" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_execute_command_keys_not_provided(self, r): + r, cache = r + assert r.execute_command("SET", "b", "2") is True + assert r.execute_command("GET", "b") == "2" # keys not provided, not cached + assert cache.get(("GET", "b")) is None + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "single_connection_client": True}], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_single_connection(self, r): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == b"barbar" + + @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) + def test_get_from_cache_invalidate_via_get(self, r, r2): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # don't send any command to redis, just run another get + # it should process the invalidation in background + assert r.get("foo") == b"barbar" + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster @@ -233,3 +460,128 @@ def test_cache_decode_response(self, r): assert cache.get(("GET", "foo")) is None # get key from redis assert r.get("foo") == "barbar" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_execute_command_keys_provided(self, r): + r, cache = r + assert r.execute_command("SET", "b", "2") is True + assert r.execute_command("GET", "b", keys=["b"]) == "2" + assert cache.get(("GET", "b")) == "2" + + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_execute_command_keys_not_provided(self, r): + r, cache = r + assert r.execute_command("SET", "b", "2") is True + assert r.execute_command("GET", "b") == "2" # keys not provided, not cached + assert cache.get(("GET", "b")) is None + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +class TestSentinelLocalCache: + + def test_get_from_cache(self, local_cache, master): + master.set("foo", "bar") + # get key from redis and save in local cache + assert master.get("foo") == b"bar" + # get key from local cache + assert local_cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + master.ping() + # the command is not in the local cache anymore + assert local_cache.get(("GET", "foo")) is None + # get key from redis + assert master.get("foo") == b"barbar" + + @pytest.mark.parametrize( + "sentinel_setup", + [{"kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_cache_decode_response(self, local_cache, sentinel_setup, master): + master.set("foo", "bar") + # get key from redis and save in local cache + assert master.get("foo") == "bar" + # get key from local cache + assert local_cache.get(("GET", "foo")) == "bar" + # change key in redis (cause invalidation) + master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + master.ping() + # the command is not in the local cache anymore + assert local_cache.get(("GET", "foo")) is None + # get key from redis + assert master.get("foo") == "barbar" + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +class TestCustomCache: + class _CustomCache(AbstractCache): + def __init__(self): + self.responses = cachetools.LRUCache(maxsize=1000) + self.keys_to_commands = defaultdict(list) + self.commands_to_keys = defaultdict(list) + + def set( + self, + command: Union[str, Sequence[str]], + response: ResponseT, + keys_in_command: List[KeyT], + ): + self.responses[command] = response + for key in keys_in_command: + self.keys_to_commands[key].append(tuple(command)) + self.commands_to_keys[command].append(tuple(keys_in_command)) + + def get(self, command: Union[str, Sequence[str]]) -> ResponseT: + return self.responses.get(command) + + def delete_command(self, command: Union[str, Sequence[str]]): + self.responses.pop(command, None) + keys = self.commands_to_keys.pop(command, []) + for key in keys: + if command in self.keys_to_commands[key]: + self.keys_to_commands[key].remove(command) + + def delete_commands(self, commands: List[Union[str, Sequence[str]]]): + for command in commands: + self.delete_command(command) + + def flush(self): + self.responses.clear() + self.commands_to_keys.clear() + self.keys_to_commands.clear() + + def invalidate_key(self, key: KeyT): + commands = self.keys_to_commands.pop(key, []) + for command in commands: + self.delete_command(command) + + @pytest.mark.parametrize("r", [{"cache": _CustomCache()}], indirect=True) + def test_get_from_cache(self, r, r2): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # send any command to redis (process invalidation in background) + r.ping() + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert r.get("foo") == b"barbar"