Skip to content

Commit 0aae758

Browse files
Feat: Improve how dask can be configured
1 parent f948fb3 commit 0aae758

File tree

4 files changed

+218
-29
lines changed

4 files changed

+218
-29
lines changed

bluepyparallel/parallel.py

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Parallel helper."""
2+
import json
23
import logging
34
import multiprocessing
45
import os
56
from abc import abstractmethod
67
from collections.abc import Iterator
8+
from copy import deepcopy
79
from functools import partial
810
from multiprocessing.pool import Pool
911

@@ -33,6 +35,7 @@
3335
except ImportError: # pragma: no cover
3436
ipyparallel_available = False
3537

38+
from bluepyparallel.utils import replace_values_in_docstring
3639

3740
L = logging.getLogger(__name__)
3841

@@ -210,15 +213,78 @@ def shutdown(self):
210213
pass
211214

212215

216+
_DEFAULT_DASK_CONFIG = {
217+
"temporary-directory": None,
218+
"distributed": {
219+
"worker": {
220+
"use_file_locking": False,
221+
"memory": {
222+
"target": False,
223+
"spill": False,
224+
"pause": 0.8,
225+
"terminate": 0.95,
226+
},
227+
"profile": {
228+
"enabled": False,
229+
},
230+
},
231+
"admin": {
232+
"tick": {
233+
"limit": "1h",
234+
},
235+
},
236+
},
237+
}
238+
239+
_DASK_CONFIG_DOCSTRING = """
240+
It is possible to pass a custom dask configuration in several ways.
241+
The simplest way is to pass a dictionary to the `dask_config` argument.
242+
Another way is to create a YAML file containing the configuration and then set the `DASK_CONFIG`
243+
environment variable to its path. Note that this environment variable must be set before `dask`
244+
is imported and cannot be updated afterwards.
245+
Also, it is possible to use the `TMPDIR` environment variable to specify the directory in which
246+
the dask internals will be created. Note that this value will be overridden if a dask configuration
247+
is given.
248+
If no config is provided, the following is used:
249+
250+
.. code-block:: JSON
251+
252+
<>
253+
""".replace(
254+
"<>",
255+
json.dumps(
256+
_DEFAULT_DASK_CONFIG,
257+
sort_keys=True,
258+
indent=4,
259+
).replace("\n", "\n" + " " * 4),
260+
)
261+
262+
263+
@replace_values_in_docstring(external_config_block=_DASK_CONFIG_DOCSTRING)
213264
class DaskFactory(ParallelFactory):
214-
"""Parallel helper class using dask."""
265+
"""Parallel helper class using dask.
266+
267+
<external_config_block>
268+
"""
215269

216270
_SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
217271

218272
def __init__(
219-
self, batch_size=None, chunk_size=None, scheduler_file=None, address=None, **kwargs
273+
self,
274+
batch_size=None,
275+
chunk_size=None,
276+
scheduler_file=None,
277+
address=None,
278+
dask_config=None,
279+
**kwargs,
220280
):
221281
"""Initialize the dask factory."""
282+
_default_config = deepcopy(_DEFAULT_DASK_CONFIG)
283+
_default_config["temporary-directory"] = os.environ.get("TMPDIR", None)
284+
dask.config.update_defaults(_default_config)
285+
if dask_config is not None: # pragma: no cover
286+
dask.config.set(dask_config)
287+
222288
dask_scheduler_path = scheduler_file or os.getenv(self._SCHEDULER_PATH)
223289
self.interactive = True
224290
if dask_scheduler_path: # pragma: no cover
@@ -269,36 +335,14 @@ def _dask_mapper(in_dask_func, iterable):
269335
return _mapper
270336

271337

338+
@replace_values_in_docstring(external_config_block=_DASK_CONFIG_DOCSTRING)
272339
class DaskDataFrameFactory(DaskFactory):
273-
"""Parallel helper class using dask.dataframe."""
340+
"""Parallel helper class using `dask.dataframe`.
274341
275-
_SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
342+
<external_config_block>
343+
"""
276344

277-
def __init__(
278-
self,
279-
batch_size=None,
280-
chunk_size=None,
281-
scheduler_file=None,
282-
address=None,
283-
dask_config=None,
284-
**kwargs,
285-
):
286-
super().__init__(
287-
batch_size, chunk_size, scheduler_file=scheduler_file, address=address, **kwargs
288-
)
289-
if dask_config is None: # pragma: no cover
290-
dask_config = {
291-
"distributed.worker.use_file_locking": False,
292-
"distributed.worker.memory.target": False,
293-
"distributed.worker.memory.spill": False,
294-
"distributed.worker.memory.pause": 0.8,
295-
"distributed.worker.memory.terminate": 95,
296-
"distributed.worker.profile.interval": "10000ms",
297-
"distributed.worker.profile.cycle": "1000000ms",
298-
"distributed.admin.tick.limit": "1h",
299-
}
300-
301-
dask.config.set(dask_config)
345+
_SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
302346

303347
def _with_batches(self, *args, **kwargs):
304348
"""Specific process for batches."""

bluepyparallel/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Some utils for the BluePyParallel package."""
2+
3+
4+
def replace_values_in_docstring(**kwargs):
5+
"""Decorator to replace keywords in docstrings by the actual value of a variable.
6+
7+
.. Note::
8+
The keyword must be enclosed in <> in the docstring, like <MyKeyword>.
9+
"""
10+
11+
def inner(func):
12+
for k, v in kwargs.items():
13+
func.__doc__ = func.__doc__.replace(f"<{k}>", str(v))
14+
return func
15+
16+
return inner

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
test_reqs = [
3535
"mpi4py>=3.0.1",
36+
"packaging>=20",
3637
"pytest>=6.1",
3738
"pytest-benchmark>=3.4",
3839
"pytest-cov>=3",

tests/test_parallel.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
"""Test the ``bluepyparallel.parallel`` module."""
22
# pylint: disable=missing-function-docstring
33
# pylint: disable=redefined-outer-name
4+
import importlib.metadata
5+
import json
6+
import subprocess
7+
import sys
48
from collections.abc import Iterator
59
from copy import deepcopy
610

711
import pandas as pd
812
import pytest
13+
import yaml
14+
from packaging.version import Version
915

1016
from bluepyparallel import init_parallel_factory
1117
from bluepyparallel.parallel import DaskDataFrameFactory
1218

19+
dask_version = Version(importlib.metadata.version("dask"))
20+
1321

1422
def _evaluation_function_range(element, coeff_a=1.0, coeff_b=1.0):
1523
"""Mock evaluation function."""
@@ -98,3 +106,123 @@ def test_bad_factory_name(self):
98106
"""Test a factory with a wrong name."""
99107
with pytest.raises(KeyError):
100108
init_parallel_factory("UNKNOWN FACTORY")
109+
110+
111+
@pytest.fixture(params=[True, False])
112+
def env_tmpdir(tmpdir, request):
113+
if request.param:
114+
dask_tmpdir = str(tmpdir / "tmpdir_from_env")
115+
yield dask_tmpdir
116+
else:
117+
yield None
118+
119+
120+
@pytest.fixture(params=[True, False])
121+
def env_daskconfig(tmpdir, request):
122+
if request.param:
123+
# Create dask config file
124+
dask_config = {"distributed": {"worker": {"memory": {"pause": 0.123456}}}}
125+
filepath = str(tmpdir / "dask_config.yml")
126+
with open(filepath, "w", encoding="utf-8") as file:
127+
yaml.dump(dask_config, file)
128+
129+
yield filepath
130+
else:
131+
yield None
132+
133+
134+
@pytest.mark.parametrize(
135+
"dask_config",
136+
[
137+
pytest.param(None, id="No Dask config"),
138+
pytest.param({"temporary-directory": "tmpdir"}, id="Dask config with tmp dir"),
139+
],
140+
)
141+
@pytest.mark.parametrize(
142+
"factory_type",
143+
[
144+
pytest.param("dask", id="Dask"),
145+
pytest.param("dask_dataframe", id="Dask-dataframe"),
146+
],
147+
)
148+
@pytest.mark.skipif(dask_version < Version("2023.4"), reason="Requires dask >= 2023.4")
149+
def test_dask_config(tmpdir, dask_config, factory_type, env_tmpdir, env_daskconfig):
150+
"""Test the methods to update dask configuration."""
151+
dask_config_tmpdir = str(tmpdir / "tmpdir_from_config")
152+
has_tmpdir = (
153+
dask_config is not None and dask_config.get("temporary-directory", None) is not None
154+
)
155+
if has_tmpdir:
156+
dask_config["temporary-directory"] = dask_config_tmpdir
157+
dask_config_str = "None" if dask_config is None else json.dumps(dask_config)
158+
159+
# Must test using a subprocess because the DASK_CONFIG environment variable is only considered
160+
# when dask is imported
161+
code = """if True: # This is just to avoid indentation issue
162+
import os
163+
164+
import dask.config
165+
import dask.distributed
166+
167+
from bluepyparallel import init_parallel_factory
168+
from bluepyparallel.parallel import DaskDataFrameFactory
169+
from bluepyparallel.parallel import DaskFactory
170+
171+
172+
dask_cluster = dask.distributed.LocalCluster(dashboard_address=None)
173+
174+
has_tmpdir = {has_tmpdir}
175+
dask_config_tmpdir = "{tmpdir}"
176+
env_tmpdir = os.getenv("TMPDIR", None)
177+
dask_config = {dask_config}
178+
env_daskconfig = os.getenv("DASK_CONFIG", None)
179+
180+
print("Values in subprocess:")
181+
print("tmpdir:", dask_config_tmpdir)
182+
print("has_tmpdir:", has_tmpdir)
183+
print("env_tmpdir:", env_tmpdir)
184+
print("env_daskconfig:", env_daskconfig)
185+
print("dask_config:", dask_config)
186+
187+
factory_kwargs = {{
188+
"address": dask_cluster,
189+
"dask_config": dask_config,
190+
}}
191+
192+
factory = init_parallel_factory("{factory_type}", **factory_kwargs)
193+
194+
print("tmpdir in dask.config:", dask.config.get("temporary-directory"))
195+
print(
196+
"distributed.worker.memory.pause in dask.config:",
197+
dask.config.get("distributed.worker.memory.pause"),
198+
)
199+
200+
if "{factory_type}" == "dask":
201+
assert isinstance(factory, DaskFactory)
202+
else:
203+
assert isinstance(factory, DaskDataFrameFactory)
204+
205+
if env_daskconfig is not None:
206+
assert dask.config.get("distributed.worker.memory.pause") == 0.123456
207+
else:
208+
assert dask.config.get("distributed.worker.memory.pause") == 0.8
209+
210+
if has_tmpdir:
211+
assert dask.config.get("temporary-directory") == dask_config_tmpdir
212+
else:
213+
if env_tmpdir is not None:
214+
assert dask.config.get("temporary-directory") == env_tmpdir
215+
else:
216+
assert dask.config.get("temporary-directory") is None
217+
""".format(
218+
dask_config=dask_config_str,
219+
factory_type=factory_type,
220+
tmpdir=dask_config_tmpdir,
221+
has_tmpdir=has_tmpdir,
222+
)
223+
envs = {}
224+
if env_daskconfig is not None:
225+
envs["DASK_CONFIG"] = env_daskconfig
226+
if env_tmpdir is not None:
227+
envs["TMPDIR"] = env_tmpdir
228+
subprocess.check_call([sys.executable, "-c", code], env=envs)

0 commit comments

Comments (0)