14
14
15
15
import logging
16
16
import threading
17
+ import weakref
17
18
from enum import Enum , auto
18
19
from typing import (
19
20
TYPE_CHECKING ,
23
24
Dict ,
24
25
Iterable ,
25
26
List ,
27
+ MutableMapping ,
26
28
Optional ,
27
29
Set ,
28
30
Tuple ,
@@ -248,6 +250,12 @@ def __init__(
248
250
str , ObservableDeferred [Dict [str , EventCacheEntry ]]
249
251
] = {}
250
252
253
+ # We keep track of the events we have currently loaded in memory so that
254
+ # we can reuse them even if they've been evicted from the cache. We only
255
+ # track events that don't need redacting in here (as then we don't need
256
+ # to track redaction status).
257
+ self ._event_ref : MutableMapping [str , EventBase ] = weakref .WeakValueDictionary ()
258
+
251
259
self ._event_fetch_lock = threading .Condition ()
252
260
self ._event_fetch_list : List [
253
261
Tuple [Iterable [str ], "defer.Deferred[Dict[str, _EventRow]]" ]
@@ -723,6 +731,8 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
723
731
724
732
def _invalidate_get_event_cache (self , event_id : str ) -> None :
725
733
self ._get_event_cache .invalidate ((event_id ,))
734
+ self ._event_ref .pop (event_id , None )
735
+ self ._current_event_fetches .pop (event_id , None )
726
736
727
737
def _get_events_from_cache (
728
738
self , events : Iterable [str ], update_metrics : bool = True
@@ -738,13 +748,30 @@ def _get_events_from_cache(
738
748
event_map = {}
739
749
740
750
for event_id in events :
751
+ # First check if it's in the event cache
741
752
ret = self ._get_event_cache .get (
742
753
(event_id ,), None , update_metrics = update_metrics
743
754
)
744
- if not ret :
755
+ if ret :
756
+ event_map [event_id ] = ret
745
757
continue
746
758
747
- event_map [event_id ] = ret
759
+ # Otherwise check if we still have the event in memory.
760
+ event = self ._event_ref .get (event_id )
761
+ if event :
762
+ # Reconstruct an event cache entry
763
+
764
+ cache_entry = EventCacheEntry (
765
+ event = event ,
766
+ # We don't cache weakrefs to redacted events, so we know
767
+ # this is None.
768
+ redacted_event = None ,
769
+ )
770
+ event_map [event_id ] = cache_entry
771
+
772
+ # We add the entry back into the cache as we want to keep
773
+ # recently queried events in the cache.
774
+ self ._get_event_cache .set ((event_id ,), cache_entry )
748
775
749
776
return event_map
750
777
@@ -1124,6 +1151,10 @@ async def _get_events_from_db(
1124
1151
self ._get_event_cache .set ((event_id ,), cache_entry )
1125
1152
result_map [event_id ] = cache_entry
1126
1153
1154
+ if not redacted_event :
1155
+ # We only cache references to unredacted events.
1156
+ self ._event_ref [event_id ] = original_ev
1157
+
1127
1158
return result_map
1128
1159
1129
1160
async def _enqueue_events (self , events : Collection [str ]) -> Dict [str , _EventRow ]:
0 commit comments