from __future__ import annotations

import asyncio
import copy
import json
import logging
import math
from typing import Any

import psutil
import toolz

from dask.system import CPU_COUNT

from distributed.compatibility import WINDOWS
from distributed.deploy.spec import ProcessInterface, SpecCluster
from distributed.deploy.utils import nprocesses_nthreads
from distributed.scheduler import Scheduler
from distributed.worker_memory import parse_memory_limit

logger = logging.getLogger(__name__)


class SubprocessWorker(ProcessInterface):
    """A local Dask worker running in a dedicated subprocess

    Parameters
    ----------
    scheduler:
        Address of the scheduler
    worker_class:
        Python class to use to create the worker, defaults to 'distributed.Nanny'
    name:
        Name of the worker
    worker_kwargs:
        Keywords to pass on to the ``worker_class`` constructor
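
    Examples
    --------
    A minimal sketch; assumes a scheduler is already listening at the given
    address and that this is awaited from within a running event loop:

    >>> w = SubprocessWorker(scheduler="tcp://127.0.0.1:8786")  # doctest: +SKIP
    >>> await w.start()  # doctest: +SKIP
    >>> await w.close()  # doctest: +SKIP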
    """

    scheduler: str
    worker_class: str
    worker_kwargs: dict
    name: str | None
    process: asyncio.subprocess.Process | None

    def __init__(
        self,
        scheduler: str,
        worker_class: str = "distributed.Nanny",
        name: str | None = None,
        worker_kwargs: dict | None = None,
    ) -> None:
        if WINDOWS:
            # FIXME: distributed#7434
            raise RuntimeError("SubprocessWorker does not support Windows.")
        self.scheduler = scheduler
        self.worker_class = worker_class
        self.name = name
        self.worker_kwargs = copy.copy(worker_kwargs or {})
        self.process = None
        super().__init__()

    async def start(self) -> None:
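        # Launch the worker out-of-process via the `dask spec` CLI, which
        # instantiates ``worker_class`` from a JSON spec and points it at the
        # scheduler address.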
        self.process = await asyncio.create_subprocess_exec(
            "dask",
            "spec",
            self.scheduler,
            "--spec",
            json.dumps({0: {"cls": self.worker_class, "opts": {**self.worker_kwargs}}}),
        )
        await super().start()

    async def close(self) -> None:
        if self.process and self.process.returncode is None:
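            # Kill the full process tree first: with the default
            # ``distributed.Nanny``, the actual worker runs as a child of the
            # subprocess started here.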
            for child in psutil.Process(self.process.pid).children(recursive=True):
                child.kill()
            self.process.kill()
            await self.process.wait()
            self.process = None
        await super().close()


def SubprocessCluster(
    host: str | None = None,
    scheduler_port: int = 0,
    scheduler_kwargs: dict | None = None,
    dashboard_address: str = ":8787",
    worker_class: str = "distributed.Nanny",
    n_workers: int | None = None,
    threads_per_worker: int | None = None,
    worker_kwargs: dict | None = None,
    silence_logs: int = logging.WARN,
    **kwargs: Any,
) -> SpecCluster:
    """Create an in-process scheduler and workers running in dedicated subprocesses

    This creates a "cluster" of a scheduler running in the current process and
    workers running in dedicated subprocesses.

    .. warning::

        This function is experimental

    Parameters
    ----------
    host:
        Host address on which the scheduler will listen, defaults to localhost
    scheduler_port:
        Port of the scheduler, defaults to 0 to choose a random port
    scheduler_kwargs:
        Keywords to pass on to the scheduler
    dashboard_address:
        Address on which to listen for the Bokeh diagnostics server like
        'localhost:8787' or '0.0.0.0:8787', defaults to ':8787'

        Set to ``None`` to disable the dashboard.
        Use ':0' for a random port.
    worker_class:
        Worker class to instantiate workers from, defaults to 'distributed.Nanny'
    n_workers:
        Number of workers to start
    threads_per_worker:
        Number of threads per worker
    worker_kwargs:
        Keywords to pass on to the ``Worker`` class constructor
    silence_logs:
        Level of logs to print out to stdout, defaults to ``logging.WARN``

        Use a falsy value like False or None to disable log silencing.

    Examples
    --------
    >>> cluster = SubprocessCluster()  # Create a subprocess cluster  # doctest: +SKIP
    >>> cluster  # doctest: +SKIP
    SubprocessCluster(SubprocessCluster, 'tcp://127.0.0.1:61207', workers=5, threads=10, memory=16.00 GiB)

    >>> c = Client(cluster)  # connect to subprocess cluster  # doctest: +SKIP

    Scale the cluster to three workers

    >>> cluster.scale(3)  # doctest: +SKIP
    """
    if WINDOWS:
        # FIXME: distributed#7434
        raise RuntimeError("SubprocessCluster does not support Windows.")
    if not host:
        host = "127.0.0.1"
    worker_kwargs = worker_kwargs or {}
    scheduler_kwargs = scheduler_kwargs or {}

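    # Partition the machine's CPUs between worker processes and threads when
    # the user does not fully specify both.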
    if n_workers is None and threads_per_worker is None:
        n_workers, threads_per_worker = nprocesses_nthreads()
    if n_workers is None and threads_per_worker is not None:
        n_workers = max(1, CPU_COUNT // threads_per_worker)
    if n_workers and threads_per_worker is None:
        # Overcommit threads per worker, rather than undercommit
        threads_per_worker = max(1, int(math.ceil(CPU_COUNT / n_workers)))
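    # By default, split the machine's total memory evenly across workers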
    if n_workers and "memory_limit" not in worker_kwargs:
        worker_kwargs["memory_limit"] = parse_memory_limit(
            "auto", 1, n_workers, logger=logger
        )
    assert n_workers is not None

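    # Merge defaults with user-provided kwargs; in ``toolz.merge``, later
    # mappings win, so explicit user values override the defaults computed here.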
    scheduler_kwargs = toolz.merge(
        {
            "host": host,
            "port": scheduler_port,
            "dashboard": dashboard_address is not None,
            "dashboard_address": dashboard_address,
        },
        scheduler_kwargs,
    )
    worker_kwargs = toolz.merge(
        {
            "host": host,
            "nthreads": threads_per_worker,
            "silence_logs": silence_logs,
        },
        worker_kwargs,
    )

    scheduler = {"cls": Scheduler, "options": scheduler_kwargs}
    worker = {
        "cls": SubprocessWorker,
        "options": {"worker_class": worker_class, "worker_kwargs": worker_kwargs},
    }
    workers = {i: worker for i in range(n_workers)}
    return SpecCluster(
        workers=workers,
        scheduler=scheduler,
        worker=worker,
        name="SubprocessCluster",
        silence_logs=silence_logs,
        **kwargs,
    )
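
# A minimal end-to-end sketch (hypothetical usage, not part of this module's
# API; assumes the `dask` CLI is on PATH and a non-Windows system):
#
#     from distributed import Client
#
#     cluster = SubprocessCluster(n_workers=2, threads_per_worker=2)
#     with Client(cluster) as client:
#         assert client.submit(lambda x: x + 1, 41).result() == 42
#     cluster.close()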