19
19
import threading
20
20
from collections import namedtuple
21
21
22
- import six
23
- from six import itervalues , string_types
22
+ from six import itervalues
24
23
25
24
from prometheus_client import Gauge
26
25
32
31
from synapse .util .caches import get_cache_factor_for
33
32
from synapse .util .caches .lrucache import LruCache
34
33
from synapse .util .caches .treecache import TreeCache , iterate_tree_cache_entry
35
- from synapse .util .stringutils import to_ascii
36
34
37
35
from . import register_cache
38
36
@@ -124,7 +122,7 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
124
122
update_metrics (bool): whether to update the cache hit rate metrics
125
123
126
124
Returns:
127
- Either a Deferred or the raw result
125
+ Either an ObservableDeferred or the raw result
128
126
"""
129
127
callbacks = [callback ] if callback else []
130
128
val = self ._pending_deferred_cache .get (key , _CacheSentinel )
@@ -148,40 +146,63 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
148
146
return default
149
147
150
148
def set (self , key , value , callback = None ):
149
+ if not isinstance (value , defer .Deferred ):
150
+ raise TypeError ("not a Deferred" )
151
+
151
152
callbacks = [callback ] if callback else []
152
153
self .check_thread ()
153
- entry = CacheEntry (deferred = value , callbacks = callbacks )
154
+ observable = ObservableDeferred (value , consumeErrors = True )
155
+ observer = defer .maybeDeferred (observable .observe )
156
+ entry = CacheEntry (deferred = observable , callbacks = callbacks )
154
157
155
158
existing_entry = self ._pending_deferred_cache .pop (key , None )
156
159
if existing_entry :
157
160
existing_entry .invalidate ()
158
161
159
162
self ._pending_deferred_cache [key ] = entry
160
163
161
- def shuffle (result ):
164
+ def compare_and_pop ():
165
+ """Check if our entry is still the one in _pending_deferred_cache, and
166
+ if so, pop it.
167
+
168
+ Returns true if the entries matched.
169
+ """
162
170
existing_entry = self ._pending_deferred_cache .pop (key , None )
163
171
if existing_entry is entry :
172
+ return True
173
+
174
+ # oops, the _pending_deferred_cache has been updated since
175
+ # we started our query, so we are out of date.
176
+ #
177
+ # Better put back whatever we took out. (We do it this way
178
+ # round, rather than peeking into the _pending_deferred_cache
179
+ # and then removing on a match, to make the common case faster)
180
+ if existing_entry is not None :
181
+ self ._pending_deferred_cache [key ] = existing_entry
182
+
183
+ return False
184
+
185
+ def cb (result ):
186
+ if compare_and_pop ():
164
187
self .cache .set (key , result , entry .callbacks )
165
188
else :
166
- # oops, the _pending_deferred_cache has been updated since
167
- # we started our query, so we are out of date.
168
- #
169
- # Better put back whatever we took out. (We do it this way
170
- # round, rather than peeking into the _pending_deferred_cache
171
- # and then removing on a match, to make the common case faster)
172
- if existing_entry is not None :
173
- self ._pending_deferred_cache [key ] = existing_entry
174
-
175
189
# we're not going to put this entry into the cache, so need
176
190
# to make sure that the invalidation callbacks are called.
177
191
# That was probably done when _pending_deferred_cache was
178
192
# updated, but it's possible that `set` was called without
179
193
# `invalidate` being previously called, in which case it may
180
194
# not have been. Either way, let's double-check now.
181
195
entry .invalidate ()
182
- return result
183
196
184
- entry .deferred .addCallback (shuffle )
197
+ def eb (_fail ):
198
+ compare_and_pop ()
199
+ entry .invalidate ()
200
+
201
+ # once the deferred completes, we can move the entry from the
202
+ # _pending_deferred_cache to the real cache.
203
+ #
204
+ observer .addCallbacks (cb , eb )
205
+ return observable
185
206
186
207
def prefill (self , key , value , callback = None ):
187
208
callbacks = [callback ] if callback else []
@@ -414,20 +435,10 @@ def onErr(f):
414
435
415
436
ret .addErrback (onErr )
416
437
417
- # If our cache_key is a string on py2, try to convert to ascii
418
- # to save a bit of space in large caches. Py3 does this
419
- # internally automatically.
420
- if six .PY2 and isinstance (cache_key , string_types ):
421
- cache_key = to_ascii (cache_key )
422
-
423
- result_d = ObservableDeferred (ret , consumeErrors = True )
424
- cache .set (cache_key , result_d , callback = invalidate_callback )
438
+ result_d = cache .set (cache_key , ret , callback = invalidate_callback )
425
439
observer = result_d .observe ()
426
440
427
- if isinstance (observer , defer .Deferred ):
428
- return make_deferred_yieldable (observer )
429
- else :
430
- return observer
441
+ return make_deferred_yieldable (observer )
431
442
432
443
if self .num_args == 1 :
433
444
wrapped .invalidate = lambda key : cache .invalidate (key [0 ])
@@ -543,16 +554,15 @@ def arg_to_cache_key(arg):
543
554
missing .add (arg )
544
555
545
556
if missing :
546
- # we need an observable deferred for each entry in the list,
557
+ # we need a deferred for each entry in the list,
547
558
# which we put in the cache. Each deferred resolves with the
548
559
# relevant result for that key.
549
560
deferreds_map = {}
550
561
for arg in missing :
551
562
deferred = defer .Deferred ()
552
563
deferreds_map [arg ] = deferred
553
564
key = arg_to_cache_key (arg )
554
- observable = ObservableDeferred (deferred )
555
- cache .set (key , observable , callback = invalidate_callback )
565
+ cache .set (key , deferred , callback = invalidate_callback )
556
566
557
567
def complete_all (res ):
558
568
# the wrapped function has completed. It returns a
0 commit comments