37
37
)
38
38
from redis .asyncio .lock import Lock
39
39
from redis .asyncio .retry import Retry
40
+ from redis .cache import (
41
+ DEFAULT_BLACKLIST ,
42
+ DEFAULT_EVICTION_POLICY ,
43
+ DEFAULT_WHITELIST ,
44
+ _LocalCache ,
45
+ )
40
46
from redis .client import (
41
47
EMPTY_RESPONSE ,
42
48
NEVER_DECODE ,
60
66
TimeoutError ,
61
67
WatchError ,
62
68
)
63
- from redis .typing import ChannelT , EncodableT , KeyT
69
+ from redis .typing import ChannelT , EncodableT , KeysT , KeyT , ResponseT
64
70
from redis .utils import (
65
71
HIREDIS_AVAILABLE ,
66
72
_set_info_logger ,
@@ -231,6 +237,13 @@ def __init__(
231
237
redis_connect_func = None ,
232
238
credential_provider : Optional [CredentialProvider ] = None ,
233
239
protocol : Optional [int ] = 2 ,
240
+ cache_enable : bool = False ,
241
+ client_cache : Optional [_LocalCache ] = None ,
242
+ cache_max_size : int = 100 ,
243
+ cache_ttl : int = 0 ,
244
+ cache_eviction_policy : str = DEFAULT_EVICTION_POLICY ,
245
+ cache_blacklist : List [str ] = DEFAULT_BLACKLIST ,
246
+ cache_whitelist : List [str ] = DEFAULT_WHITELIST ,
234
247
):
235
248
"""
236
249
Initialize a new Redis client.
@@ -336,6 +349,16 @@ def __init__(
336
349
# on a set of redis commands
337
350
self ._single_conn_lock = asyncio .Lock ()
338
351
352
+ self .client_cache = client_cache
353
+ if cache_enable :
354
+ self .client_cache = _LocalCache (
355
+ cache_max_size , cache_ttl , cache_eviction_policy
356
+ )
357
+ if self .client_cache is not None :
358
+ self .cache_blacklist = cache_blacklist
359
+ self .cache_whitelist = cache_whitelist
360
+ self .client_cache_initialized = False
361
+
339
362
def __repr__ (self ):
340
363
return (
341
364
f"<{ self .__class__ .__module__ } .{ self .__class__ .__name__ } "
@@ -350,6 +373,10 @@ async def initialize(self: _RedisT) -> _RedisT:
350
373
async with self ._single_conn_lock :
351
374
if self .connection is None :
352
375
self .connection = await self .connection_pool .get_connection ("_" )
376
+ if self .client_cache is not None :
377
+ self .connection ._parser .set_invalidation_push_handler (
378
+ self ._cache_invalidation_process
379
+ )
353
380
return self
354
381
355
382
def set_response_callback (self , command : str , callback : ResponseCallbackT ):
@@ -568,6 +595,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
568
595
close_connection_pool is None and self .auto_close_connection_pool
569
596
):
570
597
await self .connection_pool .disconnect ()
598
+ if self .client_cache :
599
+ self .client_cache .flush ()
571
600
572
601
@deprecated_function (version = "5.0.1" , reason = "Use aclose() instead" , name = "close" )
573
602
async def close (self , close_connection_pool : Optional [bool ] = None ) -> None :
@@ -596,29 +625,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
596
625
):
597
626
raise error
598
627
628
+ def _cache_invalidation_process (
629
+ self , data : List [Union [str , Optional [List [str ]]]]
630
+ ) -> None :
631
+ """
632
+ Invalidate (delete) all redis commands associated with a specific key.
633
+ `data` is a list of strings, where the first string is the invalidation message
634
+ and the second string is the list of keys to invalidate.
635
+ (if the list of keys is None, then all keys are invalidated)
636
+ """
637
+ if data [1 ] is not None :
638
+ for key in data [1 ]:
639
+ self .client_cache .invalidate (str_if_bytes (key ))
640
+ else :
641
+ self .client_cache .flush ()
642
+
643
+ async def _get_from_local_cache (self , command : str ):
644
+ """
645
+ If the command is in the local cache, return the response
646
+ """
647
+ if (
648
+ self .client_cache is None
649
+ or command [0 ] in self .cache_blacklist
650
+ or command [0 ] not in self .cache_whitelist
651
+ ):
652
+ return None
653
+ while not self .connection ._is_socket_empty ():
654
+ await self .connection .read_response (push_request = True )
655
+ return self .client_cache .get (command )
656
+
657
+ def _add_to_local_cache (
658
+ self , command : Tuple [str ], response : ResponseT , keys : List [KeysT ]
659
+ ):
660
+ """
661
+ Add the command and response to the local cache if the command
662
+ is allowed to be cached
663
+ """
664
+ if (
665
+ self .client_cache is not None
666
+ and (self .cache_blacklist == [] or command [0 ] not in self .cache_blacklist )
667
+ and (self .cache_whitelist == [] or command [0 ] in self .cache_whitelist )
668
+ ):
669
+ self .client_cache .set (command , response , keys )
670
+
671
+ def delete_from_local_cache (self , command : str ):
672
+ """
673
+ Delete the command from the local cache
674
+ """
675
+ try :
676
+ self .client_cache .delete (command )
677
+ except AttributeError :
678
+ pass
679
+
599
680
# COMMAND EXECUTION AND PROTOCOL PARSING
600
681
async def execute_command (self , * args , ** options ):
601
682
"""Execute a command and return a parsed response"""
602
683
await self .initialize ()
603
- options .pop ("keys" , None ) # the keys are used only for client side caching
604
- pool = self .connection_pool
605
684
command_name = args [0 ]
606
- conn = self .connection or await pool .get_connection (command_name , ** options )
685
+ keys = options .pop ("keys" , None ) # keys are used only for client side caching
686
+ response_from_cache = await self ._get_from_local_cache (args )
687
+ if response_from_cache is not None :
688
+ return response_from_cache
689
+ else :
690
+ pool = self .connection_pool
691
+ conn = self .connection or await pool .get_connection (command_name , ** options )
607
692
608
- if self .single_connection_client :
609
- await self ._single_conn_lock .acquire ()
610
- try :
611
- return await conn .retry .call_with_retry (
612
- lambda : self ._send_command_parse_response (
613
- conn , command_name , * args , ** options
614
- ),
615
- lambda error : self ._disconnect_raise (conn , error ),
616
- )
617
- finally :
618
693
if self .single_connection_client :
619
- self ._single_conn_lock .release ()
620
- if not self .connection :
621
- await pool .release (conn )
694
+ await self ._single_conn_lock .acquire ()
695
+ try :
696
+ if self .client_cache is not None and not self .client_cache_initialized :
697
+ await conn .retry .call_with_retry (
698
+ lambda : self ._send_command_parse_response (
699
+ conn , "CLIENT" , * ("CLIENT" , "TRACKING" , "ON" )
700
+ ),
701
+ lambda error : self ._disconnect_raise (conn , error ),
702
+ )
703
+ self .client_cache_initialized = True
704
+ response = await conn .retry .call_with_retry (
705
+ lambda : self ._send_command_parse_response (
706
+ conn , command_name , * args , ** options
707
+ ),
708
+ lambda error : self ._disconnect_raise (conn , error ),
709
+ )
710
+ self ._add_to_local_cache (args , response , keys )
711
+ return response
712
+ finally :
713
+ if self .single_connection_client :
714
+ self ._single_conn_lock .release ()
715
+ if not self .connection :
716
+ await pool .release (conn )
622
717
623
718
async def parse_response (
624
719
self , connection : Connection , command_name : Union [str , bytes ], ** options
@@ -866,7 +961,7 @@ async def connect(self):
866
961
else :
867
962
await self .connection .connect ()
868
963
if self .push_handler_func is not None and not HIREDIS_AVAILABLE :
869
- self .connection ._parser .set_push_handler (self .push_handler_func )
964
+ self .connection ._parser .set_pubsub_push_handler (self .push_handler_func )
870
965
871
966
async def _disconnect_raise_connect (self , conn , error ):
872
967
"""
0 commit comments