|
15 | 15 | # limitations under the License. |
16 | 16 |
|
17 | 17 |
|
18 | | -from mock import Mock |
19 | | - |
20 | | -from twisted.internet import defer |
21 | | - |
22 | | -from synapse.util.async_helpers import ObservableDeferred |
23 | | -from synapse.util.caches.descriptors import cached |
24 | | - |
25 | 18 | from tests import unittest |
26 | 19 |
|
27 | 20 |
|
28 | | -class CacheDecoratorTestCase(unittest.HomeserverTestCase): |
29 | | - @defer.inlineCallbacks |
30 | | - def test_passthrough(self): |
31 | | - class A: |
32 | | - @cached() |
33 | | - def func(self, key): |
34 | | - return key |
35 | | - |
36 | | - a = A() |
37 | | - |
38 | | - self.assertEquals((yield a.func("foo")), "foo") |
39 | | - self.assertEquals((yield a.func("bar")), "bar") |
40 | | - |
41 | | - @defer.inlineCallbacks |
42 | | - def test_hit(self): |
43 | | - callcount = [0] |
44 | | - |
45 | | - class A: |
46 | | - @cached() |
47 | | - def func(self, key): |
48 | | - callcount[0] += 1 |
49 | | - return key |
50 | | - |
51 | | - a = A() |
52 | | - yield a.func("foo") |
53 | | - |
54 | | - self.assertEquals(callcount[0], 1) |
55 | | - |
56 | | - self.assertEquals((yield a.func("foo")), "foo") |
57 | | - self.assertEquals(callcount[0], 1) |
58 | | - |
59 | | - @defer.inlineCallbacks |
60 | | - def test_invalidate(self): |
61 | | - callcount = [0] |
62 | | - |
63 | | - class A: |
64 | | - @cached() |
65 | | - def func(self, key): |
66 | | - callcount[0] += 1 |
67 | | - return key |
68 | | - |
69 | | - a = A() |
70 | | - yield a.func("foo") |
71 | | - |
72 | | - self.assertEquals(callcount[0], 1) |
73 | | - |
74 | | - a.func.invalidate(("foo",)) |
75 | | - |
76 | | - yield a.func("foo") |
77 | | - |
78 | | - self.assertEquals(callcount[0], 2) |
79 | | - |
80 | | - def test_invalidate_missing(self): |
81 | | - class A: |
82 | | - @cached() |
83 | | - def func(self, key): |
84 | | - return key |
85 | | - |
86 | | - A().func.invalidate(("what",)) |
87 | | - |
88 | | - @defer.inlineCallbacks |
89 | | - def test_max_entries(self): |
90 | | - callcount = [0] |
91 | | - |
92 | | - class A: |
93 | | - @cached(max_entries=10) |
94 | | - def func(self, key): |
95 | | - callcount[0] += 1 |
96 | | - return key |
97 | | - |
98 | | - a = A() |
99 | | - |
100 | | - for k in range(0, 12): |
101 | | - yield a.func(k) |
102 | | - |
103 | | - self.assertEquals(callcount[0], 12) |
104 | | - |
105 | | - # There must have been at least 2 evictions, meaning if we calculate |
106 | | - # all 12 values again, we must get called at least 2 more times |
107 | | - for k in range(0, 12): |
108 | | - yield a.func(k) |
109 | | - |
110 | | - self.assertTrue( |
111 | | - callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) |
112 | | - ) |
113 | | - |
114 | | - def test_prefill(self): |
115 | | - callcount = [0] |
116 | | - |
117 | | - d = defer.succeed(123) |
118 | | - |
119 | | - class A: |
120 | | - @cached() |
121 | | - def func(self, key): |
122 | | - callcount[0] += 1 |
123 | | - return d |
124 | | - |
125 | | - a = A() |
126 | | - |
127 | | - a.func.prefill(("foo",), ObservableDeferred(d)) |
128 | | - |
129 | | - self.assertEquals(a.func("foo").result, d.result) |
130 | | - self.assertEquals(callcount[0], 0) |
131 | | - |
132 | | - @defer.inlineCallbacks |
133 | | - def test_invalidate_context(self): |
134 | | - callcount = [0] |
135 | | - callcount2 = [0] |
136 | | - |
137 | | - class A: |
138 | | - @cached() |
139 | | - def func(self, key): |
140 | | - callcount[0] += 1 |
141 | | - return key |
142 | | - |
143 | | - @cached(cache_context=True) |
144 | | - def func2(self, key, cache_context): |
145 | | - callcount2[0] += 1 |
146 | | - return self.func(key, on_invalidate=cache_context.invalidate) |
147 | | - |
148 | | - a = A() |
149 | | - yield a.func2("foo") |
150 | | - |
151 | | - self.assertEquals(callcount[0], 1) |
152 | | - self.assertEquals(callcount2[0], 1) |
153 | | - |
154 | | - a.func.invalidate(("foo",)) |
155 | | - yield a.func("foo") |
156 | | - |
157 | | - self.assertEquals(callcount[0], 2) |
158 | | - self.assertEquals(callcount2[0], 1) |
159 | | - |
160 | | - yield a.func2("foo") |
161 | | - |
162 | | - self.assertEquals(callcount[0], 2) |
163 | | - self.assertEquals(callcount2[0], 2) |
164 | | - |
165 | | - @defer.inlineCallbacks |
166 | | - def test_eviction_context(self): |
167 | | - callcount = [0] |
168 | | - callcount2 = [0] |
169 | | - |
170 | | - class A: |
171 | | - @cached(max_entries=2) |
172 | | - def func(self, key): |
173 | | - callcount[0] += 1 |
174 | | - return key |
175 | | - |
176 | | - @cached(cache_context=True) |
177 | | - def func2(self, key, cache_context): |
178 | | - callcount2[0] += 1 |
179 | | - return self.func(key, on_invalidate=cache_context.invalidate) |
180 | | - |
181 | | - a = A() |
182 | | - yield a.func2("foo") |
183 | | - yield a.func2("foo2") |
184 | | - |
185 | | - self.assertEquals(callcount[0], 2) |
186 | | - self.assertEquals(callcount2[0], 2) |
187 | | - |
188 | | - yield a.func2("foo") |
189 | | - self.assertEquals(callcount[0], 2) |
190 | | - self.assertEquals(callcount2[0], 2) |
191 | | - |
192 | | - yield a.func("foo3") |
193 | | - |
194 | | - self.assertEquals(callcount[0], 3) |
195 | | - self.assertEquals(callcount2[0], 2) |
196 | | - |
197 | | - yield a.func2("foo") |
198 | | - |
199 | | - self.assertEquals(callcount[0], 4) |
200 | | - self.assertEquals(callcount2[0], 3) |
201 | | - |
202 | | - @defer.inlineCallbacks |
203 | | - def test_double_get(self): |
204 | | - callcount = [0] |
205 | | - callcount2 = [0] |
206 | | - |
207 | | - class A: |
208 | | - @cached() |
209 | | - def func(self, key): |
210 | | - callcount[0] += 1 |
211 | | - return key |
212 | | - |
213 | | - @cached(cache_context=True) |
214 | | - def func2(self, key, cache_context): |
215 | | - callcount2[0] += 1 |
216 | | - return self.func(key, on_invalidate=cache_context.invalidate) |
217 | | - |
218 | | - a = A() |
219 | | - a.func2.cache.cache = Mock(wraps=a.func2.cache.cache) |
220 | | - |
221 | | - yield a.func2("foo") |
222 | | - |
223 | | - self.assertEquals(callcount[0], 1) |
224 | | - self.assertEquals(callcount2[0], 1) |
225 | | - |
226 | | - a.func2.invalidate(("foo",)) |
227 | | - self.assertEquals(a.func2.cache.cache.pop.call_count, 1) |
228 | | - |
229 | | - yield a.func2("foo") |
230 | | - a.func2.invalidate(("foo",)) |
231 | | - self.assertEquals(a.func2.cache.cache.pop.call_count, 2) |
232 | | - |
233 | | - self.assertEquals(callcount[0], 1) |
234 | | - self.assertEquals(callcount2[0], 2) |
235 | | - |
236 | | - a.func.invalidate(("foo",)) |
237 | | - self.assertEquals(a.func2.cache.cache.pop.call_count, 3) |
238 | | - yield a.func("foo") |
239 | | - |
240 | | - self.assertEquals(callcount[0], 2) |
241 | | - self.assertEquals(callcount2[0], 2) |
242 | | - |
243 | | - yield a.func2("foo") |
244 | | - |
245 | | - self.assertEquals(callcount[0], 2) |
246 | | - self.assertEquals(callcount2[0], 3) |
247 | | - |
248 | | - |
249 | 21 | class UpsertManyTests(unittest.HomeserverTestCase): |
250 | 22 | def prepare(self, reactor, clock, hs): |
251 | 23 | self.storage = hs.get_datastore() |
|
0 commit comments