Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit c544308

Browse files
committed
Fix some error cases in the caching layer. (#5749)
2 parents f7bf143 + 618bd1e commit c544308

File tree

3 files changed

+130
-35
lines changed

3 files changed

+130
-35
lines changed

changelog.d/5749.misc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Fix some error cases in the caching layer.

synapse/util/caches/descriptors.py

Lines changed: 42 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,7 @@
1919
import threading
2020
from collections import namedtuple
2121

22-
import six
23-
from six import itervalues, string_types
22+
from six import itervalues
2423

2524
from prometheus_client import Gauge
2625

@@ -32,7 +31,6 @@
3231
from synapse.util.caches import get_cache_factor_for
3332
from synapse.util.caches.lrucache import LruCache
3433
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
35-
from synapse.util.stringutils import to_ascii
3634

3735
from . import register_cache
3836

@@ -124,7 +122,7 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
124122
update_metrics (bool): whether to update the cache hit rate metrics
125123
126124
Returns:
127-
Either a Deferred or the raw result
125+
Either an ObservableDeferred or the raw result
128126
"""
129127
callbacks = [callback] if callback else []
130128
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -148,40 +146,63 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
148146
return default
149147

150148
def set(self, key, value, callback=None):
149+
if not isinstance(value, defer.Deferred):
150+
raise TypeError("not a Deferred")
151+
151152
callbacks = [callback] if callback else []
152153
self.check_thread()
153-
entry = CacheEntry(deferred=value, callbacks=callbacks)
154+
observable = ObservableDeferred(value, consumeErrors=True)
155+
observer = defer.maybeDeferred(observable.observe)
156+
entry = CacheEntry(deferred=observable, callbacks=callbacks)
154157

155158
existing_entry = self._pending_deferred_cache.pop(key, None)
156159
if existing_entry:
157160
existing_entry.invalidate()
158161

159162
self._pending_deferred_cache[key] = entry
160163

161-
def shuffle(result):
164+
def compare_and_pop():
165+
"""Check if our entry is still the one in _pending_deferred_cache, and
166+
if so, pop it.
167+
168+
Returns true if the entries matched.
169+
"""
162170
existing_entry = self._pending_deferred_cache.pop(key, None)
163171
if existing_entry is entry:
172+
return True
173+
174+
# oops, the _pending_deferred_cache has been updated since
175+
# we started our query, so we are out of date.
176+
#
177+
# Better put back whatever we took out. (We do it this way
178+
# round, rather than peeking into the _pending_deferred_cache
179+
# and then removing on a match, to make the common case faster)
180+
if existing_entry is not None:
181+
self._pending_deferred_cache[key] = existing_entry
182+
183+
return False
184+
185+
def cb(result):
186+
if compare_and_pop():
164187
self.cache.set(key, result, entry.callbacks)
165188
else:
166-
# oops, the _pending_deferred_cache has been updated since
167-
# we started our query, so we are out of date.
168-
#
169-
# Better put back whatever we took out. (We do it this way
170-
# round, rather than peeking into the _pending_deferred_cache
171-
# and then removing on a match, to make the common case faster)
172-
if existing_entry is not None:
173-
self._pending_deferred_cache[key] = existing_entry
174-
175189
# we're not going to put this entry into the cache, so need
176190
# to make sure that the invalidation callbacks are called.
177191
# That was probably done when _pending_deferred_cache was
178192
# updated, but it's possible that `set` was called without
179193
# `invalidate` being previously called, in which case it may
180194
# not have been. Either way, let's double-check now.
181195
entry.invalidate()
182-
return result
183196

184-
entry.deferred.addCallback(shuffle)
197+
def eb(_fail):
198+
compare_and_pop()
199+
entry.invalidate()
200+
201+
# once the deferred completes, we can move the entry from the
202+
# _pending_deferred_cache to the real cache.
203+
#
204+
observer.addCallbacks(cb, eb)
205+
return observable
185206

186207
def prefill(self, key, value, callback=None):
187208
callbacks = [callback] if callback else []
@@ -414,20 +435,10 @@ def onErr(f):
414435

415436
ret.addErrback(onErr)
416437

417-
# If our cache_key is a string on py2, try to convert to ascii
418-
# to save a bit of space in large caches. Py3 does this
419-
# internally automatically.
420-
if six.PY2 and isinstance(cache_key, string_types):
421-
cache_key = to_ascii(cache_key)
422-
423-
result_d = ObservableDeferred(ret, consumeErrors=True)
424-
cache.set(cache_key, result_d, callback=invalidate_callback)
438+
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
425439
observer = result_d.observe()
426440

427-
if isinstance(observer, defer.Deferred):
428-
return make_deferred_yieldable(observer)
429-
else:
430-
return observer
441+
return make_deferred_yieldable(observer)
431442

432443
if self.num_args == 1:
433444
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -543,16 +554,15 @@ def arg_to_cache_key(arg):
543554
missing.add(arg)
544555

545556
if missing:
546-
# we need an observable deferred for each entry in the list,
557+
# we need a deferred for each entry in the list,
547558
# which we put in the cache. Each deferred resolves with the
548559
# relevant result for that key.
549560
deferreds_map = {}
550561
for arg in missing:
551562
deferred = defer.Deferred()
552563
deferreds_map[arg] = deferred
553564
key = arg_to_cache_key(arg)
554-
observable = ObservableDeferred(deferred)
555-
cache.set(key, observable, callback=invalidate_callback)
565+
cache.set(key, deferred, callback=invalidate_callback)
556566

557567
def complete_all(res):
558568
# the wrapped function has completed. It returns a

tests/util/caches/test_descriptors.py

Lines changed: 87 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
make_deferred_yieldable,
2828
)
2929
from synapse.util.caches import descriptors
30+
from synapse.util.caches.descriptors import cached
3031

3132
from tests import unittest
3233

@@ -55,12 +56,15 @@ def record_callback(idx):
5556
d2 = defer.Deferred()
5657
cache.set("key2", d2, partial(record_callback, 1))
5758

58-
# lookup should return the deferreds
59-
self.assertIs(cache.get("key1"), d1)
60-
self.assertIs(cache.get("key2"), d2)
59+
# lookup should return observable deferreds
60+
self.assertFalse(cache.get("key1").has_called())
61+
self.assertFalse(cache.get("key2").has_called())
6162

6263
# let one of the lookups complete
6364
d2.callback("result2")
65+
66+
# for now at least, the cache will return real results rather than an
67+
# observabledeferred
6468
self.assertEqual(cache.get("key2"), "result2")
6569

6670
# now do the invalidation
@@ -146,6 +150,28 @@ def fn(self, arg1, arg2):
146150
self.assertEqual(r, "chips")
147151
obj.mock.assert_not_called()
148152

153+
def test_cache_with_sync_exception(self):
154+
"""If the wrapped function throws synchronously, things should continue to work
155+
"""
156+
157+
class Cls(object):
158+
@cached()
159+
def fn(self, arg1):
160+
raise SynapseError(100, "mai spoon iz too big!!1")
161+
162+
obj = Cls()
163+
164+
# this should fail immediately
165+
d = obj.fn(1)
166+
self.failureResultOf(d, SynapseError)
167+
168+
# ... leaving the cache empty
169+
self.assertEqual(len(obj.fn.cache.cache), 0)
170+
171+
# and a second call should result in a second exception
172+
d = obj.fn(1)
173+
self.failureResultOf(d, SynapseError)
174+
149175
def test_cache_logcontexts(self):
150176
"""Check that logcontexts are set and restored correctly when
151177
using the cache."""
@@ -222,6 +248,9 @@ def do_lookup():
222248

223249
self.assertEqual(LoggingContext.current_context(), c1)
224250

251+
# the cache should now be empty
252+
self.assertEqual(len(obj.fn.cache.cache), 0)
253+
225254
obj = Cls()
226255

227256
# set off a deferred which will do a cache lookup
@@ -268,6 +297,61 @@ def fn(self, arg1, arg2=2, arg3=3):
268297
self.assertEqual(r, "chips")
269298
obj.mock.assert_not_called()
270299

300+
def test_cache_iterable(self):
301+
class Cls(object):
302+
def __init__(self):
303+
self.mock = mock.Mock()
304+
305+
@descriptors.cached(iterable=True)
306+
def fn(self, arg1, arg2):
307+
return self.mock(arg1, arg2)
308+
309+
obj = Cls()
310+
311+
obj.mock.return_value = ["spam", "eggs"]
312+
r = obj.fn(1, 2)
313+
self.assertEqual(r, ["spam", "eggs"])
314+
obj.mock.assert_called_once_with(1, 2)
315+
obj.mock.reset_mock()
316+
317+
# a call with different params should call the mock again
318+
obj.mock.return_value = ["chips"]
319+
r = obj.fn(1, 3)
320+
self.assertEqual(r, ["chips"])
321+
obj.mock.assert_called_once_with(1, 3)
322+
obj.mock.reset_mock()
323+
324+
# the two values should now be cached
325+
self.assertEqual(len(obj.fn.cache.cache), 3)
326+
327+
r = obj.fn(1, 2)
328+
self.assertEqual(r, ["spam", "eggs"])
329+
r = obj.fn(1, 3)
330+
self.assertEqual(r, ["chips"])
331+
obj.mock.assert_not_called()
332+
333+
def test_cache_iterable_with_sync_exception(self):
334+
"""If the wrapped function throws synchronously, things should continue to work
335+
"""
336+
337+
class Cls(object):
338+
@descriptors.cached(iterable=True)
339+
def fn(self, arg1):
340+
raise SynapseError(100, "mai spoon iz too big!!1")
341+
342+
obj = Cls()
343+
344+
# this should fail immediately
345+
d = obj.fn(1)
346+
self.failureResultOf(d, SynapseError)
347+
348+
# ... leaving the cache empty
349+
self.assertEqual(len(obj.fn.cache.cache), 0)
350+
351+
# and a second call should result in a second exception
352+
d = obj.fn(1)
353+
self.failureResultOf(d, SynapseError)
354+
271355

272356
class CachedListDescriptorTestCase(unittest.TestCase):
273357
@defer.inlineCallbacks

0 commit comments

Comments (0)