Skip to content

Commit e7c1f9e

Browse files
authored
Fix espaloma model download race conditions (#398)
1 parent 06edade commit e7c1f9e

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

openmmforcefields/generators/template_generators.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,21 +1867,24 @@ def _get_model_filepath(self, forcefield):
18671867
return cached_filename
18681868
else:
18691869
# Create the cache directory
1870-
if not os.path.exists(self.ESPALOMA_MODEL_CACHE_PATH):
1871-
os.makedirs(self.ESPALOMA_MODEL_CACHE_PATH)
1870+
os.makedirs(self.ESPALOMA_MODEL_CACHE_PATH, exist_ok=True)
18721871

18731872
# Attempt to retrieve from URL
18741873
_logger.info(f"Attempting to retrieve espaloma model from {url}")
1874+
import tempfile
18751875
import urllib
18761876
import urllib.error
18771877
import urllib.request
18781878

1879-
try:
1880-
urllib.request.urlretrieve(url, filename=cached_filename)
1881-
except urllib.error.URLError:
1882-
raise ValueError(f"No espaloma model found at expected URL: {url}")
1883-
except urllib.error.HTTPError as e:
1884-
raise ValueError(f"An error occurred while retrieving espaloma model from {url} : {e}")
1879+
with tempfile.TemporaryDirectory(dir=self.ESPALOMA_MODEL_CACHE_PATH) as temp_dir:
1880+
temp_filename = os.path.join(temp_dir, filename)
1881+
try:
1882+
urllib.request.urlretrieve(url, filename=temp_filename)
1883+
except urllib.error.URLError:
1884+
raise ValueError(f"No espaloma model found at expected URL: {url}")
1885+
except urllib.error.HTTPError as e:
1886+
raise ValueError(f"An error occurred while retrieving espaloma model from {url} : {e}")
1887+
os.replace(temp_filename, cached_filename)
18851888
return cached_filename
18861889

18871890
@property

0 commit comments

Comments
 (0)