Skip to content

Commit 6dffdeb

Browse files
authored
feature/add cache to reduce disk reading frequency (#169)
1 parent 4f41b19 commit 6dffdeb

File tree

7 files changed

+128
-27
lines changed

7 files changed

+128
-27
lines changed

demo/vdl_scratch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
import random
55
import subprocess
66

7+
78
import numpy as np
89
from PIL import Image
910
from scipy.stats import norm
1011
from visualdl import ROOT, LogWriter
12+
from visualdl.server.log import logger as log
1113

1214
logdir = './scratch_log'
1315

@@ -92,3 +94,20 @@
9294
data = np.random.random(shape).flatten()
9395
image0.add_sample(shape, list(data))
9496
image0.finish_sampling()
97+
98+
def download_graph_image():
    '''
    Download a pre-generated graph image and save it as ``graph.jpg`` in ``logdir``.

    This is a scratch demo: it does not generate an ONNX proto, it just
    downloads an image that was generated beforehand to show how the graph
    frontend works.

    For real cases, just refer to README.
    '''
    from contextlib import closing
    try:
        # Python 3 location of urlopen.
        from urllib.request import urlopen
    except ImportError:
        # Python 2 fallback: the same function the demo originally called
        # as ``urllib.urlopen``.
        from urllib import urlopen

    image_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/super_resolution_graph.png?raw=true"
    log.warning('download graph demo from {}'.format(image_url))
    # closing() guarantees the connection is released even if read() raises;
    # the Python 2 response object is not a context manager on its own.
    with closing(urlopen(image_url)) as response:
        graph_image = response.read()
    # The file name is kept as 'graph.jpg' even though the source is a PNG.
    with open(os.path.join(logdir, 'graph.jpg'), 'wb') as f:
        f.write(graph_image)
    log.warning('graph ready!')
112+
113+
download_graph_image()

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def readlines(name):
2626
VERSION_NUMBER = read('VERSION_NUMBER')
2727
LICENSE = readlines('LICENSE')[0].strip()
2828

29+
# use memcache to reduce disk read frequency.
2930
install_requires = ['Flask', 'numpy', 'Pillow', 'protobuf', 'scipy']
3031
execute_requires = ['npm', 'node', 'bash']
3132

visualdl/python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ function(py_test TARGET_NAME)
2525
endfunction()
2626

2727
py_test(test_summary SRCS test_storage.py)
28+
py_test(test_cache SRCS cache.py)

visualdl/python/cache.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import time
2+
3+
4+
class MemCache(object):
    '''
    A global dict to help cache some temporary data.

    Every value is stored together with the time it was set. ``get`` treats
    an entry that is at least ``timeout`` seconds old as missing. A
    non-positive ``timeout`` (the default is -1) disables expiry entirely.
    '''

    class Record(object):
        '''A cached value plus the timestamp at which it was stored.'''

        def __init__(self, value):
            self.time = time.time()
            self.value = value

        def clear(self):
            '''Drop the reference to the cached value.'''
            self.value = None

        def expired(self, timeout):
            '''
            Return True when ``timeout`` is positive and at least ``timeout``
            seconds have passed since the record was created.
            '''
            return timeout > 0 and time.time() - self.time >= timeout

    def __init__(self, timeout=-1):
        '''
        :param timeout: lifetime of an entry in seconds; values <= 0 mean
            entries never expire.
        '''
        self._timeout = timeout
        self._data = {}

    def set(self, key, value):
        '''Store ``value`` under ``key`` and restart its expiry clock.'''
        self._data[key] = MemCache.Record(value)

    def get(self, key):
        '''
        Return the value cached under ``key``, or None when the key is
        unknown or its record has expired.
        '''
        rcd = self._data.get(key)
        if rcd is None:
            return None
        # do not delete the key to accelerate speed; only the value is
        # released so the stale payload does not stay referenced.
        if rcd.expired(self._timeout):
            rcd.clear()
            return None
        return rcd.value
33+
34+
if __name__ == '__main__':
    import unittest

    class TestMemCacheTest(unittest.TestCase):
        '''Unit tests for MemCache, run when this file is executed directly.'''

        def setUp(self):
            self.cache = MemCache(timeout=1)

        def test_expire(self):
            # Expiry is tracked per record (MemCache itself has no
            # ``expired`` method), so inspect the stored Record directly.
            self.cache.set("message", "hello")
            record = self.cache._data["message"]
            self.assertFalse(record.expired(1))
            # Backdate the record instead of sleeping to keep the test fast.
            record.time -= 4
            self.assertTrue(record.expired(1))

        def test_have_key(self):
            self.cache.set('message', 'hello')
            self.assertTrue(self.cache.get('message'))
            # Wait just past the 1 second timeout configured in setUp.
            time.sleep(1.1)
            self.assertFalse(self.cache.get('message'))
            self.assertTrue(self.cache.get("message") is None)

    unittest.main()

visualdl/python/test_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
import numpy as np
66
from PIL import Image
7+
from visualdl import LogReader, LogWriter
78

89
pprint.pprint(sys.path)
910

10-
from visualdl import LogWriter, LogReader
1111

1212

1313
class StorageTest(unittest.TestCase):

visualdl/server/lib.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pprint
21
import re
32
import sys
43
import time
@@ -7,6 +6,7 @@
76

87
import numpy as np
98
from PIL import Image
9+
1010
from log import logger
1111

1212

@@ -90,7 +90,6 @@ def get_image_tags(storage):
9090

9191

9292
def get_image_tag_steps(storage, mode, tag):
93-
print 'image_tag_steps,mode,tag:', mode, tag
9493
# remove suffix '/x'
9594
res = re.search(r".*/([0-9]+$)", tag)
9695
sample_index = 0
@@ -211,3 +210,14 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
211210
error_info = '\n'.join(map(str, sys.exc_info()))
212211
logger.error("Unexpected error: %s" % error_info)
213212
time.sleep(time2sleep)
213+
214+
def cache_get(cache):
    '''
    Build a read-through lookup helper bound to ``cache``.

    :param cache: object providing ``get(key)`` (returning None on a miss)
        and ``set(key, value)``.
    :return: a function ``_handler(key, func, *args, **kwargs)`` that
        returns the cached value for ``key``; on a miss it calls
        ``func(*args, **kwargs)``, stores the result under ``key`` and
        returns it.
    '''

    def _handler(key, func, *args, **kwargs):
        data = cache.get(key)
        if data is None:
            # A miss and a cached None look the same here, so a ``func``
            # that returns None is called again on the next lookup.
            logger.warning('update cache %s' % key)
            data = func(*args, **kwargs)
            cache.set(key, data)
        return data

    return _handler

visualdl/server/visualDL

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ from visualdl.server import lib
1717
from visualdl.server.log import logger
1818
from visualdl.server.mock import data as mock_data
1919
from visualdl.server.mock import data as mock_tags
20+
from visualdl.python.cache import MemCache
2021
from visualdl.python.storage import (LogWriter, LogReader)
2122

2223
app = Flask(__name__, static_url_path="")
@@ -33,7 +34,7 @@ def try_call(function, *args, **kwargs):
3334
res = lib.retry(error_retry_times, function, error_sleep_time, *args,
3435
**kwargs)
3536
if not res:
36-
raise exceptions.IOError("server IO error, will retry latter.")
37+
logger.error("server temporary error, will retry latter.")
3738
return res
3839

3940

@@ -70,6 +71,14 @@ def parse_args():
7071
action="store",
7172
dest="logdir",
7273
help="log file directory")
74+
parser.add_argument(
75+
"--cache_timeout",
76+
action="store",
77+
dest="cache_timeout",
78+
type=float,
79+
default=20,
80+
help="memory cache timeout duration in seconds, default 20",
81+
)
7382
args = parser.parse_args()
7483
if not args.logdir:
7584
parser.print_help()
@@ -86,8 +95,11 @@ log_reader = LogReader(args.logdir)
8695

8796
# manually putting the graph's image at this path also works.
8897
graph_image_path = os.path.join(args.logdir, 'graph.jpg')
98+
# use a memory cache to reduce disk reading frequency.
99+
CACHE = MemCache(timeout=args.cache_timeout)
100+
cache_get = lib.cache_get(CACHE)
101+
89102

90-
# return data
91103
# status, msg, data
92104
def gen_result(status, msg, data):
93105
"""
@@ -126,52 +138,54 @@ def logdir():
126138

127139
@app.route('/data/runs')
128140
def runs():
129-
result = gen_result(0, "", lib.get_modes(log_reader))
141+
data = cache_get('/data/runs', lib.get_modes, log_reader)
142+
result = gen_result(0, "", data)
130143
return Response(json.dumps(result), mimetype='application/json')
131144

132145

133146
@app.route("/data/plugin/scalars/tags")
134147
def scalar_tags():
135-
mode = request.args.get('mode')
136-
is_debug = bool(request.args.get('debug'))
137-
result = try_call(lib.get_scalar_tags, log_reader)
138-
result = gen_result(0, "", result)
148+
data = cache_get("/data/plugin/scalars/tags", try_call,
149+
lib.get_scalar_tags, log_reader)
150+
result = gen_result(0, "", data)
139151
return Response(json.dumps(result), mimetype='application/json')
140152

141153

142154
@app.route("/data/plugin/images/tags")
143155
def image_tags():
144-
mode = request.args.get('run')
145-
result = try_call(lib.get_image_tags, log_reader)
146-
result = gen_result(0, "", result)
156+
data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags,
157+
log_reader)
158+
result = gen_result(0, "", data)
147159
return Response(json.dumps(result), mimetype='application/json')
148160

149161

150162
@app.route("/data/plugin/histograms/tags")
151163
def histogram_tags():
152-
mode = request.args.get('run')
153-
# hack to avlid IO conflicts
154-
result = try_call(lib.get_histogram_tags, log_reader)
155-
result = gen_result(0, "", result)
164+
data = cache_get("/data/plugin/histograms/tags", try_call,
165+
lib.get_histogram_tags, log_reader)
166+
result = gen_result(0, "", data)
156167
return Response(json.dumps(result), mimetype='application/json')
157168

158169

159170
@app.route('/data/plugin/scalars/scalars')
160171
def scalars():
161172
run = request.args.get('run')
162173
tag = request.args.get('tag')
163-
result = try_call(lib.get_scalar, log_reader, run, tag)
164-
result = gen_result(0, "", result)
174+
key = os.path.join('/data/plugin/scalars/scalars', run, tag)
175+
data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag)
176+
result = gen_result(0, "", data)
165177
return Response(json.dumps(result), mimetype='application/json')
166178

167179

168180
@app.route('/data/plugin/images/images')
169181
def images():
170182
mode = request.args.get('run')
171183
tag = request.args.get('tag')
184+
key = os.path.join('/data/plugin/images/images', mode, tag)
172185

173-
result = try_call(lib.get_image_tag_steps, log_reader, mode, tag)
174-
result = gen_result(0, "", result)
186+
data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode,
187+
tag)
188+
result = gen_result(0, "", data)
175189

176190
return Response(json.dumps(result), mimetype='application/json')
177191

@@ -181,21 +195,23 @@ def individual_image():
181195
mode = request.args.get('run')
182196
tag = request.args.get('tag') # include a index
183197
step_index = int(request.args.get('index')) # index of step
184-
offset = 0
185198

186-
imagefile = try_call(lib.get_invididual_image, log_reader, mode, tag,
187-
step_index)
199+
key = os.path.join('/data/plugin/images/individualImage', mode, tag,
200+
str(step_index))
201+
data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode,
202+
tag, step_index)
188203
response = send_file(
189-
imagefile, as_attachment=True, attachment_filename='img.png')
204+
data, as_attachment=True, attachment_filename='img.png')
190205
return response
191206

192207

193208
@app.route('/data/plugin/histograms/histograms')
194209
def histogram():
195210
run = request.args.get('run')
196211
tag = request.args.get('tag')
197-
result = try_call(lib.get_histogram, log_reader, run, tag)
198-
result = gen_result(0, "", result)
212+
key = os.path.join('/data/plugin/histograms/histograms', run, tag)
213+
data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag)
214+
result = gen_result(0, "", data)
199215
return Response(json.dumps(result), mimetype='application/json')
200216

201217

0 commit comments

Comments
 (0)