-
Notifications
You must be signed in to change notification settings - Fork 484
/
test_cache.py
190 lines (136 loc) · 4.29 KB
/
test_cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import tempfile
import unittest
import diskcache
import pytest
@pytest.fixture
def refresh_environment():
    """Give the test a freshly importable ``outlines`` and a clean cache setting.

    ``outlines`` keeps its cache in a module-level variable. To point a test
    at a new cache location, every ``outlines`` module is therefore dropped
    from ``sys.modules`` (forcing a re-import) and ``OUTLINES_CACHE_DIR`` is
    unset.

    Only the ``outlines`` package and its submodules are evicted. Modules that
    merely contain "outlines" somewhere in their name are left alone. On
    teardown, the value ``OUTLINES_CACHE_DIR`` had before the test (or its
    absence) is restored, so the test does not leak environment changes.
    """
    import sys

    for name in list(sys.modules):
        if name == "outlines" or name.startswith("outlines."):
            del sys.modules[name]

    previous_cache_dir = os.environ.pop("OUTLINES_CACHE_DIR", None)

    yield

    # Undo whatever the test (or dependent fixtures) did to the variable.
    if previous_cache_dir is None:
        os.environ.pop("OUTLINES_CACHE_DIR", None)
    else:
        os.environ["OUTLINES_CACHE_DIR"] = previous_cache_dir


@pytest.fixture
def test_cache(refresh_environment):
    """Yield the ``outlines`` cache decorator, backed by a throwaway directory.

    ``OUTLINES_CACHE_DIR`` is pointed at a fresh temporary directory before
    ``outlines`` is (re-)imported, so ``outlines.get_cache()`` opens its cache
    there. On teardown the cache is emptied and closed before the directory is
    removed. The environment variable is then unset so it no longer refers to
    a deleted path.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        os.environ["OUTLINES_CACHE_DIR"] = tempdir
        import outlines

        memory = outlines.get_cache()
        assert memory.directory == tempdir

        yield outlines.caching.cache()

        memory.clear()
        # Release the SQLite connection diskcache holds open. Some platforms
        # (notably Windows) cannot delete a directory containing an open file.
        memory.close()
        os.environ.pop("OUTLINES_CACHE_DIR", None)


def test_get_cache(test_cache):
    """A cached function is executed only once per distinct argument."""
    import outlines

    assert isinstance(outlines.get_cache(), diskcache.Cache)

    # Every real execution of `f` appends to `calls`; a cache hit does not.
    calls = []

    @test_cache
    def f(x):
        calls.append(1)
        return x

    f(1)
    baseline = len(calls)

    # Same argument again: served from the cache, no new execution.
    f(1)
    assert len(calls) == baseline

    # New argument: cache miss, so `f` really runs once more.
    f(2)
    assert len(calls) == baseline + 1


def test_disable_cache(test_cache):
    """After `outlines.disable_cache()`, identical calls re-execute the function."""
    import outlines

    outlines.disable_cache()

    # With caching switched off, `calls` must grow on every invocation of `f`.
    calls = []

    @test_cache
    def f(x):
        calls.append(1)
        return x

    f(1)
    baseline = len(calls)

    f(1)
    assert len(calls) == baseline + 1


def test_clear_cache(test_cache):
    """`outlines.clear_cache()` evicts entries so the next call re-executes."""
    import outlines

    calls = []

    @test_cache
    def f(x):
        calls.append(1)
        return x

    # Prime the cache; the repeated call is a hit and leaves `calls` unchanged.
    f(1)
    baseline = len(calls)
    f(1)
    assert len(calls) == baseline

    # Once the cache has been cleared, the same call must run `f` again.
    outlines.clear_cache()
    f(1)
    assert len(calls) == baseline + 1


def test_version_upgrade_cache_invalidate(test_cache, mocker):
    """Bumping the outlines version invalidates entries written by an older one.

    This lets a cached function change the shape of its return value across
    releases without old on-disk results being served.
    """
    import outlines.caching

    def restart():
        # Mimic a fresh interpreter: drop the in-memory lru_cache'd handle
        # returned by `get_cache`, while leaving the on-disk entries intact.
        outlines.caching.get_cache.cache_clear()

    def set_version(version):
        mocker.patch("outlines._version.__version__", new=version)

    set_version("0.0.0")
    restart()

    # Populate the cache with a 3-tuple result.
    @test_cache
    def foo():
        return (1, 2, 3)

    _, _, _ = foo()

    # Restart on the SAME version and redefine `foo` to return a 2-tuple.
    # The stale 3-tuple is still served, so 2-way unpacking fails.
    restart()

    @test_cache
    def foo():
        return (1, 2)

    with pytest.raises(ValueError):
        _, _ = foo()

    # Restart on a NEWER version: the stale entry is ignored and the fresh
    # 2-tuple is computed and returned.
    set_version("0.0.1")
    restart()

    @test_cache
    def foo():
        return (1, 2)

    _, _ = foo()


def test_cache_disabled_decorator(test_cache):
    """Ensure the cache can be bypassed within a local ``cache_disabled()`` scope."""
    # `import unittest` alone does not load the `unittest.mock` submodule, so
    # import it explicitly rather than relying on a plugin having done so.
    from unittest.mock import MagicMock

    from outlines.caching import cache_disabled

    mock = MagicMock()

    @test_cache
    def fn():
        mock()
        return 1

    # First call is a miss: `fn` runs.
    assert fn() == 1
    assert mock.call_count == 1

    # Second call is a hit: `fn` does not run, and the cached value is returned.
    assert fn() == 1
    assert mock.call_count == 1

    # Inside the `cache_disabled()` context manager the cache is bypassed.
    with cache_disabled():
        assert fn() == 1
    assert mock.call_count == 2

    # The scope has exited, so caching is active again.
    assert fn() == 1
    assert mock.call_count == 2