"""Test the ``bluepyparallel.parallel`` module."""
# pylint: disable=missing-function-docstring
# pylint: disable=redefined-outer-name
+import importlib.metadata
+import json
+import subprocess
+import sys
from collections.abc import Iterator
from copy import deepcopy

import pandas as pd
import pytest
+import yaml
+from packaging.version import Version

from bluepyparallel import init_parallel_factory
from bluepyparallel.parallel import DaskDataFrameFactory

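+# Installed Dask version, used below to skip the configuration test on Dask < 2023.4.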
+dask_version = Version(importlib.metadata.version("dask"))
+

def _evaluation_function_range(element, coeff_a=1.0, coeff_b=1.0):
    """Mock evaluation function."""
@@ -98,3 +106,123 @@ def test_bad_factory_name(self): |
98 | 106 | """Test a factory with a wrong name.""" |
99 | 107 | with pytest.raises(KeyError): |
100 | 108 | init_parallel_factory("UNKNOWN FACTORY") |
| 109 | + |
| 110 | + |
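+# Parametrized fixture: a temporary directory path to be exported as TMPDIR in the subprocess,
+# or None to leave TMPDIR unset.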
+@pytest.fixture(params=[True, False])
+def env_tmpdir(tmpdir, request):
+    if request.param:
+        dask_tmpdir = str(tmpdir / "tmpdir_from_env")
+        yield dask_tmpdir
+    else:
+        yield None
+
+
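+# Parametrized fixture: the path of a Dask YAML config file to be exported as DASK_CONFIG in the
+# subprocess, or None to leave DASK_CONFIG unset.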
+@pytest.fixture(params=[True, False])
+def env_daskconfig(tmpdir, request):
+    if request.param:
+        # Create dask config file
+        dask_config = {"distributed": {"worker": {"memory": {"pause": 0.123456}}}}
+        filepath = str(tmpdir / "dask_config.yml")
+        with open(filepath, "w", encoding="utf-8") as file:
+            yaml.dump(dask_config, file)
+
+        yield filepath
+    else:
+        yield None
+
+
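+# Cross the explicit dask_config argument (with or without a temporary-directory entry), the
+# factory type, and the TMPDIR / DASK_CONFIG environment variables set by the fixtures above.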
+@pytest.mark.parametrize(
+    "dask_config",
+    [
+        pytest.param(None, id="No Dask config"),
+        pytest.param({"temporary-directory": "tmpdir"}, id="Dask config with tmp dir"),
+    ],
+)
+@pytest.mark.parametrize(
+    "factory_type",
+    [
+        pytest.param("dask", id="Dask"),
+        pytest.param("dask_dataframe", id="Dask-dataframe"),
+    ],
+)
+@pytest.mark.skipif(dask_version < Version("2023.4"), reason="Requires dask >= 2023.4")
+def test_dask_config(tmpdir, dask_config, factory_type, env_tmpdir, env_daskconfig):
+    """Test the methods that update the Dask configuration."""
+    dask_config_tmpdir = str(tmpdir / "tmpdir_from_config")
+    has_tmpdir = (
+        dask_config is not None and dask_config.get("temporary-directory", None) is not None
+    )
+    if has_tmpdir:
+        dask_config["temporary-directory"] = dask_config_tmpdir
+    dask_config_str = "None" if dask_config is None else json.dumps(dask_config)
+
+    # Must run in a subprocess because the DASK_CONFIG environment variable is only considered
+    # when dask is imported
+    code = """if True:  # This is just to avoid indentation issue
+        import os
+
+        import dask.config
+        import dask.distributed
+
+        from bluepyparallel import init_parallel_factory
+        from bluepyparallel.parallel import DaskDataFrameFactory
+        from bluepyparallel.parallel import DaskFactory
+
+
+        dask_cluster = dask.distributed.LocalCluster(dashboard_address=None)
+
+        has_tmpdir = {has_tmpdir}
+        dask_config_tmpdir = "{tmpdir}"
+        env_tmpdir = os.getenv("TMPDIR", None)
+        dask_config = {dask_config}
+        env_daskconfig = os.getenv("DASK_CONFIG", None)
+
+        print("Values in subprocess:")
+        print("tmpdir:", dask_config_tmpdir)
+        print("has_tmpdir:", has_tmpdir)
+        print("env_tmpdir:", env_tmpdir)
+        print("env_daskconfig:", env_daskconfig)
+        print("dask_config:", dask_config)
+
+        factory_kwargs = {{
+            "address": dask_cluster,
+            "dask_config": dask_config,
+        }}
+
+        factory = init_parallel_factory("{factory_type}", **factory_kwargs)
+
+        print("tmpdir in dask.config:", dask.config.get("temporary-directory"))
+        print(
+            "distributed.worker.memory.pause in dask.config:",
+            dask.config.get("distributed.worker.memory.pause"),
+        )
+
+        if "{factory_type}" == "dask":
+            assert isinstance(factory, DaskFactory)
+        else:
+            assert isinstance(factory, DaskDataFrameFactory)
+
+        if env_daskconfig is not None:
+            assert dask.config.get("distributed.worker.memory.pause") == 0.123456
+        else:
+            assert dask.config.get("distributed.worker.memory.pause") == 0.8
+
+        if has_tmpdir:
+            assert dask.config.get("temporary-directory") == dask_config_tmpdir
+        else:
+            if env_tmpdir is not None:
+                assert dask.config.get("temporary-directory") == env_tmpdir
+            else:
+                assert dask.config.get("temporary-directory") is None
+    """.format(
+        dask_config=dask_config_str,
+        factory_type=factory_type,
+        tmpdir=dask_config_tmpdir,
+        has_tmpdir=has_tmpdir,
+    )
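+    # Forward only the environment variables provided by the fixtures to the subprocess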
+    envs = {}
+    if env_daskconfig is not None:
+        envs["DASK_CONFIG"] = env_daskconfig
+    if env_tmpdir is not None:
+        envs["TMPDIR"] = env_tmpdir
+    subprocess.check_call([sys.executable, "-c", code], env=envs)
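
For context, a minimal sketch of the call pattern this test exercises (the cluster address and the configuration values are illustrative, not part of the commit):

    import dask.distributed

    from bluepyparallel import init_parallel_factory

    # Start a local cluster and forward a Dask configuration dict to the factory,
    # mirroring the factory_kwargs built in test_dask_config above.
    cluster = dask.distributed.LocalCluster(dashboard_address=None)
    factory = init_parallel_factory(
        "dask_dataframe",
        address=cluster,
        dask_config={"temporary-directory": "/tmp/bluepyparallel-tmp"},  # hypothetical path
    )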