7979from synapse .util import unwrapFirstError
8080from synapse .util .async_helpers import ObservableDeferred , delay_cancellation
8181from synapse .util .caches .descriptors import cached , cachedList
82- from synapse .util .caches .lrucache import LruCache
82+ from synapse .util .caches .lrucache import AsyncLruCache
8383from synapse .util .iterutils import batch_iter
8484from synapse .util .metrics import Measure
8585
@@ -238,7 +238,9 @@ def __init__(
238238 5 * 60 * 1000 ,
239239 )
240240
241- self ._get_event_cache : LruCache [Tuple [str ], EventCacheEntry ] = LruCache (
241+ self ._get_event_cache : AsyncLruCache [
242+ Tuple [str ], EventCacheEntry
243+ ] = AsyncLruCache (
242244 cache_name = "*getEvent*" ,
243245 max_size = hs .config .caches .event_cache_size ,
244246 )
@@ -598,7 +600,7 @@ async def _get_events_from_cache_or_db(
598600 Returns:
599601 map from event id to result
600602 """
601- event_entry_map = self ._get_events_from_cache (
603+ event_entry_map = await self ._get_events_from_cache (
602604 event_ids ,
603605 )
604606
@@ -710,12 +712,22 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
710712
711713 return event_entry_map
712714
713- def _invalidate_get_event_cache (self , event_id : str ) -> None :
714- self ._get_event_cache .invalidate ((event_id ,))
715+ async def _invalidate_get_event_cache (self , event_id : str ) -> None :
716+ # First we invalidate the asynchronous cache instance. This may include
717+ # out-of-process caches such as Redis/memcache. Once complete we can
718+ # invalidate any in memory cache. The ordering is important here to
719+ # ensure we don't pull in any remote invalid value after we invalidate
720+ # the in-memory cache.
721+ await self ._get_event_cache .invalidate ((event_id ,))
715722 self ._event_ref .pop (event_id , None )
716723 self ._current_event_fetches .pop (event_id , None )
717724
718- def _get_events_from_cache (
725+ def _invalidate_local_get_event_cache (self , event_id : str ) -> None :
726+ self ._get_event_cache .invalidate_local ((event_id ,))
727+ self ._event_ref .pop (event_id , None )
728+ self ._current_event_fetches .pop (event_id , None )
729+
730+ async def _get_events_from_cache (
719731 self , events : Iterable [str ], update_metrics : bool = True
720732 ) -> Dict [str , EventCacheEntry ]:
721733 """Fetch events from the caches.
@@ -730,7 +742,7 @@ def _get_events_from_cache(
730742
731743 for event_id in events :
732744 # First check if it's in the event cache
733- ret = self ._get_event_cache .get (
745+ ret = await self ._get_event_cache .get (
734746 (event_id ,), None , update_metrics = update_metrics
735747 )
736748 if ret :
@@ -752,7 +764,7 @@ def _get_events_from_cache(
752764
753765 # We add the entry back into the cache as we want to keep
754766 # recently queried events in the cache.
755- self ._get_event_cache .set ((event_id ,), cache_entry )
767+ await self ._get_event_cache .set ((event_id ,), cache_entry )
756768
757769 return event_map
758770
@@ -1129,7 +1141,7 @@ async def _get_events_from_db(
11291141 event = original_ev , redacted_event = redacted_event
11301142 )
11311143
1132- self ._get_event_cache .set ((event_id ,), cache_entry )
1144+ await self ._get_event_cache .set ((event_id ,), cache_entry )
11331145 result_map [event_id ] = cache_entry
11341146
11351147 if not redacted_event :
@@ -1363,7 +1375,9 @@ async def _have_seen_events_dict(
13631375 # if the event cache contains the event, obviously we've seen it.
13641376
13651377 cache_results = {
1366- (rid , eid ) for (rid , eid ) in keys if self ._get_event_cache .contains ((eid ,))
1378+ (rid , eid )
1379+ for (rid , eid ) in keys
1380+ if await self ._get_event_cache .contains ((eid ,))
13671381 }
13681382 results = dict .fromkeys (cache_results , True )
13691383 remaining = [k for k in keys if k not in cache_results ]
0 commit comments