Skip to content

Commit 2115f48

Browse files
committed
ensure atomic writes when saving a file
Right now, if a program crashes in the middle of saving, you lose all your data.
1 parent 2b06b62 commit 2115f48

File tree

5 files changed

+10
-19
lines changed

5 files changed

+10
-19
lines changed

adaptive/learner/base_learner.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22

33
import abc
4-
import os
54
from contextlib import suppress
65
from copy import deepcopy
76

@@ -175,16 +174,8 @@ def save(self, fname, compress=True):
175174
using compression, one must load it with compression too.
176175
"""
177176
data = self._get_data()
178-
179-
backup = os.path.join(".", fname)
180-
if os.path.exists(fname):
181-
os.rename(backup)
182-
183177
save(fname, data, compress)
184178

185-
if os.path.exists(backup):
186-
os.remove(backup)
187-
188179
def load(self, fname, compress=True):
189180
"""Load the data of a learner from a pickle file.
190181
@@ -196,12 +187,6 @@ def load(self, fname, compress=True):
196187
If the data is compressed when saved, one must load it
197188
with compression too.
198189
"""
199-
with suppress(FileNotFoundError):
200-
backup = os.path.join(".", fname)
201-
if os.path.getctime(fname) < os.path.getctime(backup):
202-
# the backup is the newer file
203-
fname = backup
204-
205190
with suppress(FileNotFoundError, EOFError):
206191
data = load(fname, compress)
207192
self._set_data(data)

adaptive/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from contextlib import contextmanager
88
from itertools import product
99

10+
from atomicwrites import AtomicWriter
1011

1112
def named_product(**items):
1213
names = items.keys()
@@ -44,10 +45,13 @@ def save(fname, data, compress=True):
4445
dirname = os.path.dirname(fname)
4546
if dirname:
4647
os.makedirs(dirname, exist_ok=True)
47-
_open = gzip.open if compress else open
48-
with _open(fname, "wb") as f:
49-
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
5048

49+
blob = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
50+
if compress:
51+
blob = gzip.compress(blob)
52+
53+
with AtomicWriter(fname, 'wb', overwrite=True).open() as f:
54+
f.write(blob)
5155

5256
def load(fname, compress=True):
5357
fname = os.path.expanduser(fname)

docs/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies:
1212
- bokeh=1.0.4
1313
- plotly=3.9.0
1414
- ipywidgets=7.4.2
15+
- atomicwrites=1.3.0
1516
- pip:
1617
- git+https://github.com/basnijholt/jupyter-sphinx.git@widgets_execute
1718
- sphinx_fontawesome==0.0.6

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ dependencies:
1616
- ipywidgets
1717
- scikit-optimize
1818
- plotly
19+
- atomicwrites

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_version_and_cmdclass(package_name):
2424
version, cmdclass = get_version_and_cmdclass("adaptive")
2525

2626

27-
install_requires = ["scipy", "sortedcollections >= 1.1", "sortedcontainers >= 2.0"]
27+
install_requires = ["scipy", "sortedcollections >= 1.1", "sortedcontainers >= 2.0", "atomicwrites"]
2828

2929
extras_require = {
3030
"notebook": [

0 commit comments

Comments
 (0)