Commit b5a2078

Add local SubprocessCluster that runs workers in separate processes (#7431)
1 parent: 875207b

File tree: 4 files changed (+276 −1 lines)


distributed/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -27,7 +27,13 @@
     wait,
 )
 from distributed.core import Status, connect, rpc
-from distributed.deploy import Adaptive, LocalCluster, SpecCluster, SSHCluster
+from distributed.deploy import (
+    Adaptive,
+    LocalCluster,
+    SpecCluster,
+    SSHCluster,
+    SubprocessCluster,
+)
 from distributed.diagnostics.plugin import (
     CondaInstall,
     Environ,
@@ -134,6 +140,7 @@ def _():
     "SpecCluster",
     "Status",
     "Sub",
+    "SubprocessCluster",
     "TimeoutError",
     "UploadDirectory",
     "UploadFile",

distributed/deploy/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from distributed.deploy.local import LocalCluster
 from distributed.deploy.spec import ProcessInterface, SpecCluster
 from distributed.deploy.ssh import SSHCluster
+from distributed.deploy.subprocess import SubprocessCluster
 
 with suppress(ImportError):
     from distributed.deploy.ssh import SSHCluster

distributed/deploy/subprocess.py

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
+from __future__ import annotations
+
+import asyncio
+import copy
+import json
+import logging
+import math
+from typing import Any
+
+import psutil
+import toolz
+
+from dask.system import CPU_COUNT
+
+from distributed.compatibility import WINDOWS
+from distributed.deploy.spec import ProcessInterface, SpecCluster
+from distributed.deploy.utils import nprocesses_nthreads
+from distributed.scheduler import Scheduler
+from distributed.worker_memory import parse_memory_limit
+
+logger = logging.getLogger(__name__)
+
+
+class SubprocessWorker(ProcessInterface):
+    """A local Dask worker running in a dedicated subprocess
+
+    Parameters
+    ----------
+    scheduler:
+        Address of the scheduler
+    worker_class:
+        Python class to use to create the worker, defaults to 'distributed.Nanny'
+    name:
+        Name of the worker
+    worker_kwargs:
+        Keywords to pass on to the ``Worker`` class constructor
+    """
+
+    scheduler: str
+    worker_class: str
+    worker_kwargs: dict
+    name: str | None
+    process: asyncio.subprocess.Process | None
+
+    def __init__(
+        self,
+        scheduler: str,
+        worker_class: str = "distributed.Nanny",
+        name: str | None = None,
+        worker_kwargs: dict | None = None,
+    ) -> None:
+        if WINDOWS:
+            # FIXME: distributed#7434
+            raise RuntimeError("SubprocessWorker does not support Windows.")
+        self.scheduler = scheduler
+        self.worker_class = worker_class
+        self.name = name
+        self.worker_kwargs = copy.copy(worker_kwargs or {})
+        self.process = None
+        super().__init__()
+
+    async def start(self) -> None:
+        self.process = await asyncio.create_subprocess_exec(
+            "dask",
+            "spec",
+            self.scheduler,
+            "--spec",
+            json.dumps({0: {"cls": self.worker_class, "opts": {**self.worker_kwargs}}}),
+        )
+        await super().start()
+
+    async def close(self) -> None:
+        if self.process and self.process.returncode is None:
+            for child in psutil.Process(self.process.pid).children(recursive=True):
+                child.kill()
+            self.process.kill()
+            await self.process.wait()
+        self.process = None
+        await super().close()
+
+
+def SubprocessCluster(
+    host: str | None = None,
+    scheduler_port: int = 0,
+    scheduler_kwargs: dict | None = None,
+    dashboard_address: str = ":8787",
+    worker_class: str = "distributed.Nanny",
+    n_workers: int | None = None,
+    threads_per_worker: int | None = None,
+    worker_kwargs: dict | None = None,
+    silence_logs: int = logging.WARN,
+    **kwargs: Any,
+) -> SpecCluster:
+    """Create in-process scheduler and workers running in dedicated subprocesses
+
+    This creates a "cluster" of a scheduler running in the current process and
+    workers running in dedicated subprocesses.
+
+    .. warning::
+
+        This function is experimental
+
+    Parameters
+    ----------
+    host:
+        Host address on which the scheduler will listen, defaults to localhost
+    scheduler_port:
+        Port of the scheduler, defaults to 0 to choose a random port
+    scheduler_kwargs:
+        Keywords to pass on to scheduler
+    dashboard_address:
+        Address on which to listen for the Bokeh diagnostics server like
+        'localhost:8787' or '0.0.0.0:8787', defaults to ':8787'
+
+        Set to ``None`` to disable the dashboard.
+        Use ':0' for a random port.
+    worker_class:
+        Worker class to instantiate workers from, defaults to 'distributed.Nanny'
+    n_workers:
+        Number of workers to start
+    threads_per_worker:
+        Number of threads per worker
+    worker_kwargs:
+        Keywords to pass on to the ``Worker`` class constructor
+    silence_logs:
+        Level of logs to print out to stdout, defaults to ``logging.WARN``
+
+        Use a falsy value like False or None to disable log silencing.
+
+    Examples
+    --------
+    >>> cluster = SubprocessCluster()  # Create a subprocess cluster  # doctest: +SKIP
+    >>> cluster  # doctest: +SKIP
+    SubprocessCluster(SubprocessCluster, 'tcp://127.0.0.1:61207', workers=5, threads=10, memory=16.00 GiB)
+
+    >>> c = Client(cluster)  # connect to subprocess cluster  # doctest: +SKIP
+
+    Scale the cluster to three workers
+
+    >>> cluster.scale(3)  # doctest: +SKIP
+    """
+    if WINDOWS:
+        # FIXME: distributed#7434
+        raise RuntimeError("SubprocessCluster does not support Windows.")
+    if not host:
+        host = "127.0.0.1"
+    worker_kwargs = worker_kwargs or {}
+    scheduler_kwargs = scheduler_kwargs or {}
+
+    if n_workers is None and threads_per_worker is None:
+        n_workers, threads_per_worker = nprocesses_nthreads()
+    if n_workers is None and threads_per_worker is not None:
+        n_workers = max(1, CPU_COUNT // threads_per_worker)
+    if n_workers and threads_per_worker is None:
+        # Overcommit threads per worker, rather than undercommit
+        threads_per_worker = max(1, int(math.ceil(CPU_COUNT / n_workers)))
+    if n_workers and "memory_limit" not in worker_kwargs:
+        worker_kwargs["memory_limit"] = parse_memory_limit(
+            "auto", 1, n_workers, logger=logger
+        )
+    assert n_workers is not None
+
+    scheduler_kwargs = toolz.merge(
+        {
+            "host": host,
+            "port": scheduler_port,
+            "dashboard": dashboard_address is not None,
+            "dashboard_address": dashboard_address,
+        },
+        scheduler_kwargs,
+    )
+    worker_kwargs = toolz.merge(
+        {
+            "host": host,
+            "nthreads": threads_per_worker,
+            "silence_logs": silence_logs,
+        },
+        worker_kwargs,
+    )
+
+    scheduler = {"cls": Scheduler, "options": scheduler_kwargs}
+    worker = {
+        "cls": SubprocessWorker,
+        "options": {"worker_class": worker_class, "worker_kwargs": worker_kwargs},
+    }
+    workers = {i: worker for i in range(n_workers)}
+    return SpecCluster(
+        workers=workers,
+        scheduler=scheduler,
+        worker=worker,
+        name="SubprocessCluster",
+        silence_logs=silence_logs,
+        **kwargs,
+    )
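
Mechanically, ``SubprocessWorker.start`` shells out to the ``dask spec`` CLI, passing the worker class and its keyword arguments as a JSON spec, while ``close`` kills the spawned process tree via ``psutil``. A minimal usage sketch based on the docstring example above (worker counts, addresses, and memory in the repr are machine-dependent):

    from distributed import Client, SubprocessCluster

    cluster = SubprocessCluster()  # scheduler in this process, workers in subprocesses
    client = Client(cluster)       # connect a client, as with LocalCluster

    # Tasks execute on the subprocess workers
    assert client.submit(lambda x: x + 1, 10).result() == 11

    cluster.scale(3)               # resize the cluster to three workers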

distributed/deploy/tests/test_subprocess.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import pytest
+
+from distributed import Client
+from distributed.compatibility import WINDOWS
+from distributed.deploy.subprocess import SubprocessCluster, SubprocessWorker
+from distributed.utils_test import gen_test
+
+
+@pytest.mark.skipif(WINDOWS, reason="distributed#7434")
+@gen_test()
+async def test_basic():
+    async with SubprocessCluster(
+        asynchronous=True,
+        dashboard_address=":0",
+        scheduler_kwargs={"idle_timeout": "5s"},
+        worker_kwargs={"death_timeout": "5s"},
+    ) as cluster:
+        async with Client(cluster, asynchronous=True) as client:
+            result = await client.submit(lambda x: x + 1, 10)
+            assert result == 11
+        assert cluster._supports_scaling
+        assert "Subprocess" in repr(cluster)
+
+
+@pytest.mark.skipif(WINDOWS, reason="distributed#7434")
+@gen_test()
+async def test_n_workers():
+    async with SubprocessCluster(
+        asynchronous=True, dashboard_address=":0", n_workers=2
+    ) as cluster:
+        async with Client(cluster, asynchronous=True) as client:
+            assert len(cluster.workers) == 2
+            result = await client.submit(lambda x: x + 1, 10)
+            assert result == 11
+        assert cluster._supports_scaling
+        assert "Subprocess" in repr(cluster)
+
+
+@pytest.mark.skipif(WINDOWS, reason="distributed#7434")
+@gen_test()
+async def test_scale_up_and_down():
+    async with SubprocessCluster(
+        n_workers=0,
+        silence_logs=False,
+        dashboard_address=":0",
+        asynchronous=True,
+    ) as cluster:
+        async with Client(cluster, asynchronous=True) as c:
+
+            assert not cluster.workers
+
+            cluster.scale(2)
+            await c.wait_for_workers(2)
+            assert len(cluster.workers) == 2
+            assert len(cluster.scheduler.workers) == 2
+
+            cluster.scale(1)
+            await cluster
+
+            assert len(cluster.workers) == 1
+
+
+@pytest.mark.skipif(
+    not WINDOWS, reason="Windows-specific error testing (distributed#7434)"
+)
+def test_raise_on_windows():
+    with pytest.raises(RuntimeError, match="not support Windows"):
+        SubprocessCluster()
+
+    with pytest.raises(RuntimeError, match="not support Windows"):
+        SubprocessWorker(scheduler="tcp://127.0.0.1:8786")
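
The tests above drive the cluster with ``asynchronous=True``; for completeness, a minimal async sketch of the same pattern in user code (mirroring ``test_basic``, minus the test-only timeouts):

    import asyncio

    from distributed import Client, SubprocessCluster

    async def main():
        async with SubprocessCluster(asynchronous=True, dashboard_address=":0") as cluster:
            async with Client(cluster, asynchronous=True) as client:
                result = await client.submit(lambda x: x + 1, 10)
                print(result)  # prints 11

    asyncio.run(main())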
